"docs/git@developer.sourcefind.cn:OpenDAS/dgl.git" did not exist on "1cb210c8326fc09ac0d06edc8cee96a38ae39550"
Unverified commit b372b3c7 authored by Chao Ma, committed by GitHub

[KVStore] Add fast-pull for kvstore (#1647)

* add fast-pull

* update

* add fast-pull

* update

* update

* update

* update test

* update

* update

* update

* update

* update

* update

* update

* update

* add omp

* update

* update
parent e8a56dc1
@@ -1032,49 +1032,58 @@ class KVClient(object):
         tensor
             a data tensor with the same row size of id_tensor.
         """
-        #TODO(chao) : add C++ rpc interface and add fast pull
         assert len(name) > 0, 'name cannot be empty.'
         assert F.ndim(id_tensor) == 1, 'ID must be a vector.'
-        # partition data
-        machine_id = self._part_policy[name].to_partid(id_tensor)
-        # sort index by machine id
-        sorted_id = F.tensor(np.argsort(F.asnumpy(machine_id)))
-        back_sorted_id = F.tensor(np.argsort(F.asnumpy(sorted_id)))
-        id_tensor = id_tensor[sorted_id]
-        machine, count = np.unique(F.asnumpy(machine_id), return_counts=True)
-        # pull data from servers in machine order
-        start = 0
-        pull_count = 0
-        local_id = None
-        for idx, machine_idx in enumerate(machine):
-            end = start + count[idx]
-            if start == end:  # no data for this machine
-                continue
-            partial_id = id_tensor[start:end]
-            if machine_idx == self._machine_id:  # local pull
-                # Note: do NOT pull the local data right away, so that the
-                # local pull can overlap with the remote communication.
-                local_id = self._part_policy[name].to_local(partial_id)
-            else:  # pull data from a remote server
-                request = PullRequest(name, partial_id)
-                rpc.send_request_to_machine(machine_idx, request)
-                pull_count += 1
-            start += count[idx]
-        # receive responses
-        response_list = []
-        if local_id is not None:  # local pull
-            local_data = self._pull_handlers[name](self._data_store, name, local_id)
-            server_id = self._main_server_id
-            local_response = PullResponse(server_id, local_data)
-            response_list.append(local_response)
-        # wait for responses from the remote servers
-        for _ in range(pull_count):
-            remote_response = rpc.recv_response()
-            response_list.append(remote_response)
-        # sort responses by server_id and concatenate the tensors
-        response_list.sort(key=self._take_id)
-        data_tensor = F.cat(seq=[response.data_tensor for response in response_list], dim=0)
-        return data_tensor[back_sorted_id]  # return data in the original index order
+        if self._pull_handlers[name] is default_pull_handler:  # use fast-pull
+            part_id = self._part_policy[name].to_partid(id_tensor)
+            return rpc.fast_pull(name, id_tensor, part_id, KVSTORE_PULL,
+                                 self._machine_count,
+                                 self._group_count,
+                                 self._machine_id,
+                                 self._client_id,
+                                 self._data_store[name],
+                                 self._part_policy[name])
+        else:
+            # partition data
+            machine_id = self._part_policy[name].to_partid(id_tensor)
+            # sort index by machine id
+            sorted_id = F.tensor(np.argsort(F.asnumpy(machine_id)))
+            back_sorted_id = F.tensor(np.argsort(F.asnumpy(sorted_id)))
+            id_tensor = id_tensor[sorted_id]
+            machine, count = np.unique(F.asnumpy(machine_id), return_counts=True)
+            # pull data from servers in machine order
+            start = 0
+            pull_count = 0
+            local_id = None
+            for idx, machine_idx in enumerate(machine):
+                end = start + count[idx]
+                if start == end:  # no data for this machine
+                    continue
+                partial_id = id_tensor[start:end]
+                if machine_idx == self._machine_id:  # local pull
+                    # Note: do NOT pull the local data right away, so that the
+                    # local pull can overlap with the remote communication.
+                    local_id = self._part_policy[name].to_local(partial_id)
+                else:  # pull data from a remote server
+                    request = PullRequest(name, partial_id)
+                    rpc.send_request_to_machine(machine_idx, request)
+                    pull_count += 1
+                start += count[idx]
+            # receive responses
+            response_list = []
+            if local_id is not None:  # local pull
+                local_data = self._pull_handlers[name](self._data_store, name, local_id)
+                server_id = self._main_server_id
+                local_response = PullResponse(server_id, local_data)
+                response_list.append(local_response)
+            # wait for responses from the remote servers
+            for _ in range(pull_count):
+                remote_response = rpc.recv_response()
+                response_list.append(remote_response)
+            # sort responses by server_id and concatenate the tensors
+            response_list.sort(key=self._take_id)
+            data_tensor = F.cat(seq=[response.data_tensor for response in response_list], dim=0)
+            return data_tensor[back_sorted_id]  # return data in the original index order
 
     def _take_id(self, elem):
         """Used by sort response list
...
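The else-branch above relies on an argsort trick: sort the requested IDs by their owning machine, issue one request per machine, then invert the permutation to restore the caller's order. A standalone sketch of that trick in plain NumPy (the arrays are made-up examples, independent of DGL):

    import numpy as np

    ids = np.array([7, 2, 9, 4, 0])           # IDs a client wants to pull
    machine_id = np.array([1, 0, 1, 0, 2])    # owning machine of each ID

    sorted_idx = np.argsort(machine_id)       # group the IDs machine by machine
    back_idx = np.argsort(sorted_idx)         # permutation that undoes the sort
    grouped = ids[sorted_idx]

    machines, counts = np.unique(machine_id, return_counts=True)
    start = 0
    for m, c in zip(machines, counts):
        print('machine', m, '<-', grouped[start:start + c])
        start += c
    # Concatenating the per-machine results in this same order and indexing
    # the concatenation with back_idx restores the original request order.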
@@ -14,7 +14,7 @@ __all__ = ['set_rank', 'get_rank', 'Request', 'Response', 'register_service', \
            'receiver_wait', 'add_receiver_addr', 'sender_connect', 'read_ip_config', \
            'get_num_machines', 'set_num_machines', 'get_machine_id', 'set_machine_id', \
            'send_request', 'recv_request', 'send_response', 'recv_response', 'remote_call', \
-           'send_request_to_machine', 'remote_call_to_machine']
+           'send_request_to_machine', 'remote_call_to_machine', 'fast_pull']
 
 REQUEST_CLASS_TO_SERVICE_ID = {}
 RESPONSE_CLASS_TO_SERVICE_ID = {}
@@ -844,6 +844,55 @@ def finalize_server():
     finalize_receiver()
     print("Server (%d) shutdown." % get_rank())
 
+def fast_pull(name, id_tensor, part_id, service_id,
+              machine_count, group_count, machine_id,
+              client_id, local_data, policy):
+    """Fast-pull API used by the kvstore.
+
+    Parameters
+    ----------
+    name : str
+        data name
+    id_tensor : tensor
+        IDs of the data to pull
+    part_id : tensor
+        partition ID of each entry in id_tensor
+    service_id : int
+        service ID of the pull request
+    machine_count : int
+        total number of machines
+    group_count : int
+        number of servers on each machine
+    machine_id : int
+        current machine ID
+    client_id : int
+        current client ID
+    local_data : tensor
+        local data tensor
+    policy : PartitionPolicy
+        stores the partition information
+    """
+    msg_seq = incr_msg_seq()
+    pickle_data = bytearray(pickle.dumps(([0], [name])))
+    # Map the locally owned global IDs to local row offsets.
+    global_id = _CAPI_DGLRPCGetGlobalIDFromLocalPartition(
+        F.zerocopy_to_dgl_ndarray(id_tensor),
+        F.zerocopy_to_dgl_ndarray(part_id),
+        machine_id)
+    global_id = F.zerocopy_from_dgl_ndarray(global_id)
+    g2l_id = policy.to_local(global_id)
+    res_tensor = _CAPI_DGLRPCFastPull(name,
+                                      int(machine_id),
+                                      int(machine_count),
+                                      int(group_count),
+                                      int(client_id),
+                                      int(service_id),
+                                      int(msg_seq),
+                                      pickle_data,
+                                      F.zerocopy_to_dgl_ndarray(id_tensor),
+                                      F.zerocopy_to_dgl_ndarray(part_id),
+                                      F.zerocopy_to_dgl_ndarray(g2l_id),
+                                      F.zerocopy_to_dgl_ndarray(local_data))
+    return F.zerocopy_from_dgl_ndarray(res_tensor)
+
 ############### Some basic services will be defined here #############
 
 CLIENT_REGISTER = 22451
...
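For intuition, a minimal NumPy sketch of the selection that _CAPI_DGLRPCGetGlobalIDFromLocalPartition performs before policy.to_local() maps the result to local row offsets; the Python function name here is a hypothetical stand-in, only the selection logic is taken from the C++ body later in this diff:

    import numpy as np

    def global_ids_on_local_partition(id_tensor, part_id, machine_id):
        # Select the requested global IDs that live on this machine.
        id_tensor = np.asarray(id_tensor)
        part_id = np.asarray(part_id)
        return id_tensor[part_id == machine_id]

    print(global_ids_on_local_partition([7, 2, 9, 4], [1, 0, 1, 0], machine_id=0))
    # -> [2 4]; policy.to_local() then maps these global IDs to local rows.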
@@ -76,14 +76,24 @@ def start_server(server_id, ip_config, num_clients, server_state, \
     rpc.send_response(client_id, register_res)
     # main service loop
     while True:
-        req, client_id = rpc.recv_request()
-        res = req.process_request(server_state)
-        if res is not None:
-            if isinstance(res, list):
-                for response in res:
-                    target_id, res_data = response
-                    rpc.send_response(target_id, res_data)
-            elif isinstance(res, str) and res == 'exit':
-                break  # break the loop and exit server
-            else:
-                rpc.send_response(client_id, res)
+        try:
+            req, client_id = rpc.recv_request()
+            res = req.process_request(server_state)
+            if res is not None:
+                if isinstance(res, list):
+                    for response in res:
+                        target_id, res_data = response
+                        rpc.send_response(target_id, res_data)
+                elif isinstance(res, str) and res == 'exit':
+                    break  # break the loop and exit server
+                else:
+                    rpc.send_response(client_id, res)
+        except KeyboardInterrupt:
+            print("Exit kvserver!")
+            rpc.finalize_sender()
+            rpc.finalize_receiver()
+        except:
+            print("Error on kvserver!")
+            rpc.finalize_sender()
+            rpc.finalize_receiver()
+            raise
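The server change wraps the service loop so that Ctrl-C and unexpected errors both release network resources before the process exits. A minimal sketch of that pattern, not a line-for-line copy (recv_one, handle, and finalize are hypothetical stand-ins for rpc.recv_request, Request.process_request, and the rpc.finalize_sender()/rpc.finalize_receiver() pair; unlike the diff, the sketch also leaves the loop after Ctrl-C):

    def serve(recv_one, handle, finalize):
        while True:
            try:
                request = recv_one()
                if handle(request) == 'exit':
                    break  # normal shutdown requested by a client
            except KeyboardInterrupt:
                finalize()  # Ctrl-C: release network resources, then leave
                break
            except Exception:
                finalize()  # any other error: clean up, then re-raise
                raise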
@@ -7,6 +7,8 @@
 #include <dgl/runtime/container.h>
 #include <dgl/packed_func_ext.h>
+#include <dgl/array.h>
+#include <dgl/random.h>
 #include <dgl/zerocopy_serializer.h>
 
 #include "../c_api_common.h"
@@ -34,6 +36,9 @@ RPCStatus SendRPCMessage(const RPCMessage& msg, const int32_t target_id) {
   for (auto ptr : zc_write_strm.buffer_list()) {
     network::Message ndarray_data_msg;
     ndarray_data_msg.data = reinterpret_cast<char*>(ptr.data);
+    if (ptr.size == 0) {
+      LOG(FATAL) << "Cannot send an empty NDArray.";
+    }
     ndarray_data_msg.size = ptr.size;
     NDArray tensor = ptr.tensor;
     ndarray_data_msg.deallocator = [tensor](network::Message*) {};
@@ -295,5 +300,130 @@ DGL_REGISTER_GLOBAL("distributed.server_state._CAPI_DGLRPCGetServerState")
     *rv = st;
   });
 
+//////////////////////////// KVStore ////////////////////////////
+
+DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCGetGlobalIDFromLocalPartition")
+.set_body([] (DGLArgs args, DGLRetValue* rv) {
+  NDArray ID = args[0];
+  NDArray part_id = args[1];
+  int local_machine_id = args[2];
+  int64_t* ID_data = static_cast<int64_t*>(ID->data);
+  int64_t* part_id_data = static_cast<int64_t*>(part_id->data);
+  int64_t ID_size = ID.GetSize() / sizeof(int64_t);
+  std::vector<int64_t> global_id;
+  // Keep only the IDs whose partition is the local machine.
+  for (int64_t i = 0; i < ID_size; ++i) {
+    if (part_id_data[i] == local_machine_id) {
+      global_id.push_back(ID_data[i]);
+    }
+  }
+  NDArray res_tensor = dgl::aten::VecToIdArray<int64_t>(global_id);
+  *rv = res_tensor;
+});
+
+DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCFastPull")
+.set_body([] (DGLArgs args, DGLRetValue* rv) {
+  // Input
+  std::string name = args[0];
+  int local_machine_id = args[1];
+  int machine_count = args[2];
+  int group_count = args[3];
+  int client_id = args[4];
+  int service_id = args[5];
+  int msg_seq = args[6];
+  std::string pickle_data = args[7];
+  NDArray ID = args[8];
+  NDArray part_id = args[9];
+  NDArray local_id = args[10];
+  NDArray local_data = args[11];
+  // Data
+  dgl_id_t ID_size = ID.GetSize() / sizeof(dgl_id_t);
+  dgl_id_t* ID_data = static_cast<dgl_id_t*>(ID->data);
+  dgl_id_t* part_id_data = static_cast<dgl_id_t*>(part_id->data);
+  dgl_id_t* local_id_data = static_cast<dgl_id_t*>(local_id->data);
+  char* local_data_char = static_cast<char*>(local_data->data);
+  std::vector<dgl_id_t> local_ids;
+  std::vector<dgl_id_t> local_ids_original;
+  std::vector<int64_t> local_data_shape;
+  std::vector<std::vector<dgl_id_t> > remote_ids(machine_count);
+  std::vector<std::vector<dgl_id_t> > remote_ids_original(machine_count);
+  // Get row size (in bytes)
+  int row_size = 1;
+  for (int i = 0; i < local_data->ndim; ++i) {
+    local_data_shape.push_back(local_data->shape[i]);
+    if (i != 0) {
+      row_size *= local_data->shape[i];
+    }
+  }
+  row_size *= (local_data->dtype.bits / 8);
+  size_t data_size = local_data.GetSize();
+  CHECK_GT(local_data_shape.size(), 0);
+  CHECK_EQ(row_size * local_data_shape[0], data_size);
+  // Split the requested IDs into a local set (resolved on this machine)
+  // and per-machine remote sets (sent to remote servers).
+  dgl_id_t idx = 0;
+  for (dgl_id_t i = 0; i < ID_size; ++i) {
+    dgl_id_t p_id = part_id_data[i];
+    if (p_id == local_machine_id) {
+      dgl_id_t l_id = local_id_data[idx++];
+      CHECK_LT(l_id, local_data_shape[0]);
+      CHECK_GE(l_id, 0);
+      local_ids.push_back(l_id);
+      local_ids_original.push_back(i);
+    } else {
+      CHECK_LT(p_id, machine_count) << "Invalid partition ID.";
+      dgl_id_t id = ID_data[i];
+      remote_ids[p_id].push_back(id);
+      remote_ids_original[p_id].push_back(i);
+    }
+  }
+  // Send the remote IDs, one request per target machine, to a randomly
+  // chosen server in that machine's group.
+  int msg_count = 0;
+  for (int i = 0; i < static_cast<int>(remote_ids.size()); ++i) {
+    if (remote_ids[i].size() != 0) {
+      RPCMessage msg;
+      msg.service_id = service_id;
+      msg.msg_seq = msg_seq;
+      msg.client_id = client_id;
+      int lower = i * group_count;
+      int upper = (i + 1) * group_count;
+      msg.server_id = dgl::RandomEngine::ThreadLocal()->RandInt(lower, upper);
+      msg.data = pickle_data;
+      NDArray tensor = dgl::aten::VecToIdArray<dgl_id_t>(remote_ids[i]);
+      msg.tensors.push_back(tensor);
+      SendRPCMessage(msg, msg.server_id);
+      msg_count++;
+    }
+  }
+  local_data_shape[0] = ID_size;
+  NDArray res_tensor = NDArray::Empty(local_data_shape,
+                                      local_data->dtype,
+                                      DLContext{kDLCPU, 0});
+  char* return_data = static_cast<char*>(res_tensor->data);
+  // Copy local data while the remote requests are in flight.
+  #pragma omp parallel for
+  for (int64_t i = 0; i < static_cast<int64_t>(local_ids.size()); ++i) {
+    CHECK_GE(ID_size * row_size, local_ids_original[i] * row_size + row_size);
+    CHECK_GE(data_size, local_ids[i] * row_size + row_size);
+    CHECK_GE(local_ids[i], 0);
+    memcpy(return_data + local_ids_original[i] * row_size,
+           local_data_char + local_ids[i] * row_size,
+           row_size);
+  }
+  // Receive the remote responses and scatter each row back to the
+  // position its ID occupied in the original request.
+  for (int i = 0; i < msg_count; ++i) {
+    RPCMessage msg;
+    RecvRPCMessage(&msg, 0);
+    int part_id = msg.server_id / group_count;
+    char* data_char = static_cast<char*>(msg.tensors[0]->data);
+    dgl_id_t id_size = remote_ids[part_id].size();
+    for (size_t n = 0; n < id_size; ++n) {
+      memcpy(return_data + remote_ids_original[part_id][n] * row_size,
+             data_char + n * row_size,
+             row_size);
+    }
+  }
+  *rv = res_tensor;
+});
+
 } // namespace rpc
 } // namespace dgl
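For intuition only, a NumPy sketch of how _CAPI_DGLRPCFastPull assembles its result tensor: rows owned locally are copied out of the local shard while the remote requests are in flight, and each remote response is scattered back to the positions its IDs held in the original query. The function and argument names are hypothetical; only the gather/scatter logic mirrors the memcpy loops above:

    import numpy as np

    def assemble(ids, part_id, local_rows, local_data, remote_answers, machine_id):
        # ids: the original query; part_id: owning machine of each ID;
        # local_rows: local row offsets of locally owned IDs, in query order;
        # remote_answers: {machine: row block for its IDs, in query order}.
        part_id = np.asarray(part_id)
        result = np.empty((len(ids),) + local_data.shape[1:], local_data.dtype)
        local_pos = np.nonzero(part_id == machine_id)[0]
        result[local_pos] = local_data[local_rows]      # the OpenMP memcpy loop
        for m, rows in remote_answers.items():
            result[np.nonzero(part_id == m)[0]] = rows  # the recv-loop memcpy
        return result

    local = np.arange(12.).reshape(6, 2)                # this machine's shard
    print(assemble(ids=np.array([5, 9, 1]),
                   part_id=np.array([0, 1, 0]),         # IDs 5 and 1 live here
                   local_rows=np.array([5, 1]),
                   local_data=local,
                   remote_answers={1: np.array([[99., 99.]])},
                   machine_id=0))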
@@ -88,7 +88,7 @@ def init_zero_func(shape, dtype):
     return F.zeros(shape, dtype, F.cpu())
 
 def udf_push(target, name, id_tensor, data_tensor):
-    target[name] = F.scatter_row(target[name], id_tensor, data_tensor*data_tensor)
+    target[name][id_tensor] = data_tensor * data_tensor
 
 @unittest.skipIf(os.name == 'nt' or os.getenv('DGLBACKEND') == 'tensorflow', reason='Do not support windows and TF yet')
 def test_partition_policy():
...
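The revised udf_push writes the squared payload into the target rows in place instead of rebuilding the tensor with F.scatter_row. A tiny NumPy sketch of the same semantics (the store and arrays are made-up examples):

    import numpy as np

    # In-place row update, mirroring target[name][id_tensor] = data ** 2.
    store = {'embed': np.zeros((4, 2))}
    ids = np.array([1, 3])
    payload = np.array([[1., 2.], [3., 4.]])
    store['embed'][ids] = payload * payload
    print(store['embed'][1])  # -> [1. 4.]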