Unverified Commit 338f24cf authored by Chao Ma's avatar Chao Ma Committed by GitHub
Browse files

[KVstore] Fast-pull (#1446)

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* fix lint
parent ad222fb9
...@@ -5,6 +5,7 @@ from ..network import _network_wait, _add_receiver_addr ...@@ -5,6 +5,7 @@ from ..network import _network_wait, _add_receiver_addr
from ..network import _receiver_wait, _sender_connect from ..network import _receiver_wait, _sender_connect
from ..network import _send_kv_msg, _recv_kv_msg from ..network import _send_kv_msg, _recv_kv_msg
from ..network import _clear_kv_msg from ..network import _clear_kv_msg
from ..network import _fast_pull
from ..network import KVMsgType, KVStoreMsg from ..network import KVMsgType, KVStoreMsg
from .. import backend as F from .. import backend as F
...@@ -146,6 +147,11 @@ class KVServer(object): ...@@ -146,6 +147,11 @@ class KVServer(object):
self._open_file_list = [] self._open_file_list = []
# record for total message count # record for total message count
self._msg_count = 0 self._msg_count = 0
# user-defined push handler
self._udf_push_handler = None
self._udf_push_param = None
# user-defined pull handler
self._udf_pull_handler = None
def __del__(self): def __del__(self):
...@@ -317,6 +323,8 @@ class KVServer(object): ...@@ -317,6 +323,8 @@ class KVServer(object):
# Get connected with all client nodes # Get connected with all client nodes
_receiver_wait(self._receiver, self._ip, self._port, self._client_count) _receiver_wait(self._receiver, self._ip, self._port, self._client_count)
print("%d clients connected!" % self._client_count)
# recv client address information # recv client address information
addr_list = [] addr_list = []
for i in range(self._client_count): for i in range(self._client_count):
...@@ -378,14 +386,20 @@ class KVServer(object): ...@@ -378,14 +386,20 @@ class KVServer(object):
local_id = self._data_store[msg.name+'-g2l-'][msg.id] local_id = self._data_store[msg.name+'-g2l-'][msg.id]
else: else:
local_id = msg.id local_id = msg.id
self._push_handler(msg.name+'-data-', local_id, msg.data, self._data_store) if self._udf_push_handler is not None:
self._udf_push_handler(msg.name+'-data-', local_id, msg.data, self._data_store, self._udf_push_param)
else:
self._default_push_handler(msg.name+'-data-', local_id, msg.data, self._data_store)
# Pull message # Pull message
elif msg.type == KVMsgType.PULL: elif msg.type == KVMsgType.PULL:
if (msg.name+'-g2l-' in self._has_data) == True: if (msg.name+'-g2l-' in self._has_data) == True:
local_id = self._data_store[msg.name+'-g2l-'][msg.id] local_id = self._data_store[msg.name+'-g2l-'][msg.id]
else: else:
local_id = msg.id local_id = msg.id
res_tensor = self._pull_handler(msg.name+'-data-', local_id, self._data_store) if self._udf_pull_handler is not None:
res_tensor = self._udf_pull_handler(msg.name+'-data-', local_id, self._data_store)
else:
res_tensor = self._default_pull_handler(msg.name+'-data-', local_id, self._data_store)
back_msg = KVStoreMsg( back_msg = KVStoreMsg(
type=KVMsgType.PULL_BACK, type=KVMsgType.PULL_BACK,
rank=self._server_id, rank=self._server_id,
...@@ -500,7 +514,7 @@ class KVServer(object): ...@@ -500,7 +514,7 @@ class KVServer(object):
return data_shape return data_shape
def _push_handler(self, name, ID, data, target): def _default_push_handler(self, name, ID, data, target):
"""Default handler for PUSH message. """Default handler for PUSH message.
On default, _push_handler perform update operation for the tensor. On default, _push_handler perform update operation for the tensor.
...@@ -519,7 +533,7 @@ class KVServer(object): ...@@ -519,7 +533,7 @@ class KVServer(object):
target[name][ID] = data target[name][ID] = data
def _pull_handler(self, name, ID, target): def _default_pull_handler(self, name, ID, target):
"""Default handler for PULL operation. """Default handler for PULL operation.
On default, _pull_handler perform get operation for the tensor. On default, _pull_handler perform get operation for the tensor.
...@@ -582,6 +596,7 @@ class KVClient(object): ...@@ -582,6 +596,7 @@ class KVClient(object):
self._server_namebook = server_namebook self._server_namebook = server_namebook
self._server_count = len(server_namebook) self._server_count = len(server_namebook)
self._group_count = server_namebook[0][3] self._group_count = server_namebook[0][3]
self._machine_count = int(self._server_count / self._group_count)
# client ID will be assign by server after connecting to server # client ID will be assign by server after connecting to server
self._client_id = -1 self._client_id = -1
# Get local machine id via server_namebook # Get local machine id via server_namebook
...@@ -593,6 +608,11 @@ class KVClient(object): ...@@ -593,6 +608,11 @@ class KVClient(object):
self._open_file_list = [] self._open_file_list = []
# Gargage_collection # Gargage_collection
self._garbage_msg = [] self._garbage_msg = []
# User-defined pull handler
self._udf_pull_handler = None
# User-defined push handler
self._udf_push_handler = None
self._udf_push_param = None
# Used load-balance # Used load-balance
random.seed(time.time()) random.seed(time.time())
...@@ -812,7 +832,10 @@ class KVClient(object): ...@@ -812,7 +832,10 @@ class KVClient(object):
start += count[idx] start += count[idx]
if local_id is not None: # local push if local_id is not None: # local push
self._push_handler(name+'-data-', local_id, local_data, self._data_store) if self._udf_push_handler is not None:
self._udf_push_handler(name+'-data-', local_id, local_data, self._data_store, self._udf_push_param)
else:
self._default_push_handler(name+'-data-', local_id, local_data, self._data_store)
def pull(self, name, id_tensor): def pull(self, name, id_tensor):
...@@ -833,73 +856,88 @@ class KVClient(object): ...@@ -833,73 +856,88 @@ class KVClient(object):
assert len(name) > 0, 'name cannot be empty.' assert len(name) > 0, 'name cannot be empty.'
assert F.ndim(id_tensor) == 1, 'ID must be a vector.' assert F.ndim(id_tensor) == 1, 'ID must be a vector.'
for msg in self._garbage_msg: if self._udf_pull_handler is None: # Use fast-pull
_clear_kv_msg(msg) g2l = None
self._garbage_msg = [] if name+'-g2l-' in self._data_store:
g2l = self._data_store[name+'-g2l-']
# partition data return _fast_pull(name, id_tensor,
machine_id = self._data_store[name+'-part-'][id_tensor] self._machine_count,
# sort index by machine id self._group_count,
sorted_id = F.tensor(np.argsort(F.asnumpy(machine_id))) self._machine_id,
back_sorted_id = F.tensor(np.argsort(F.asnumpy(sorted_id))) self._client_id,
id_tensor = id_tensor[sorted_id] self._data_store[name+'-part-'],
machine, count = np.unique(F.asnumpy(machine_id), return_counts=True) g2l,
# pull data from server by order self._data_store[name+'-data-'],
start = 0 self._sender,
pull_count = 0 self._receiver)
local_id = None else:
for idx in range(len(machine)): for msg in self._garbage_msg:
end = start + count[idx] _clear_kv_msg(msg)
if start == end: # No data for target machine self._garbage_msg = []
continue
partial_id = id_tensor[start:end] # partition data
if machine[idx] == self._machine_id: # local pull machine_id = self._data_store[name+'-part-'][id_tensor]
# Note that DO NOT pull local data right now because we can overlap # sort index by machine id
# communication-local_pull here sorted_id = F.tensor(np.argsort(F.asnumpy(machine_id)))
if (name+'-g2l-' in self._has_data) == True: back_sorted_id = F.tensor(np.argsort(F.asnumpy(sorted_id)))
local_id = self._data_store[name+'-g2l-'][partial_id] id_tensor = id_tensor[sorted_id]
else: machine, count = np.unique(F.asnumpy(machine_id), return_counts=True)
local_id = partial_id # pull data from server by order
else: # pull data from remote server start = 0
msg = KVStoreMsg( pull_count = 0
type=KVMsgType.PULL, local_id = None
rank=self._client_id, for idx in range(len(machine)):
end = start + count[idx]
if start == end: # No data for target machine
continue
partial_id = id_tensor[start:end]
if machine[idx] == self._machine_id: # local pull
# Note that DO NOT pull local data right now because we can overlap
# communication-local_pull here
if (name+'-g2l-' in self._has_data) == True:
local_id = self._data_store[name+'-g2l-'][partial_id]
else:
local_id = partial_id
else: # pull data from remote server
msg = KVStoreMsg(
type=KVMsgType.PULL,
rank=self._client_id,
name=name,
id=partial_id,
data=None,
c_ptr=None)
# randomly select a server node in target machine for load-balance
s_id = random.randint(machine[idx]*self._group_count, (machine[idx]+1)*self._group_count-1)
_send_kv_msg(self._sender, msg, s_id)
pull_count += 1
start += count[idx]
msg_list = []
if local_id is not None: # local pull
local_data = self._udf_pull_handler(name+'-data-', local_id, self._data_store)
s_id = random.randint(self._machine_id*self._group_count, (self._machine_id+1)*self._group_count-1)
local_msg = KVStoreMsg(
type=KVMsgType.PULL_BACK,
rank=s_id,
name=name, name=name,
id=partial_id, id=None,
data=None, data=local_data,
c_ptr=None) c_ptr=None)
# randomly select a server node in target machine for load-balance msg_list.append(local_msg)
s_id = random.randint(machine[idx]*self._group_count, (machine[idx]+1)*self._group_count-1) self._garbage_msg.append(local_msg)
_send_kv_msg(self._sender, msg, s_id)
pull_count += 1
start += count[idx]
msg_list = []
if local_id is not None: # local pull
local_data = self._pull_handler(name+'-data-', local_id, self._data_store)
s_id = random.randint(self._machine_id*self._group_count, (self._machine_id+1)*self._group_count-1)
local_msg = KVStoreMsg(
type=KVMsgType.PULL_BACK,
rank=s_id,
name=name,
id=None,
data=local_data,
c_ptr=None)
msg_list.append(local_msg)
self._garbage_msg.append(local_msg)
# wait message from server nodes # wait message from server nodes
for idx in range(pull_count): for idx in range(pull_count):
remote_msg = _recv_kv_msg(self._receiver) remote_msg = _recv_kv_msg(self._receiver)
msg_list.append(remote_msg) msg_list.append(remote_msg)
self._garbage_msg.append(remote_msg) self._garbage_msg.append(remote_msg)
# sort msg by server id and merge tensor together # sort msg by server id and merge tensor together
msg_list.sort(key=self._takeId) msg_list.sort(key=self._takeId)
data_tensor = F.cat(seq=[msg.data for msg in msg_list], dim=0) data_tensor = F.cat(seq=[msg.data for msg in msg_list], dim=0)
return data_tensor[back_sorted_id] # return data with original index order return data_tensor[back_sorted_id] # return data with original index order
def barrier(self): def barrier(self):
...@@ -1082,7 +1120,7 @@ class KVClient(object): ...@@ -1082,7 +1120,7 @@ class KVClient(object):
return elem.rank return elem.rank
def _push_handler(self, name, ID, data, target): def _default_push_handler(self, name, ID, data, target):
"""Default handler for PUSH message. """Default handler for PUSH message.
On default, _push_handler perform update operation for the tensor. On default, _push_handler perform update operation for the tensor.
...@@ -1099,26 +1137,4 @@ class KVClient(object): ...@@ -1099,26 +1137,4 @@ class KVClient(object):
self._data_store self._data_store
""" """
target[name][ID] = data target[name][ID] = data
def _pull_handler(self, name, ID, target):
"""Default handler for PULL operation.
On default, _pull_handler perform get operation for the tensor.
Parameters
----------
name : str
data name
ID : tensor (mx.ndarray or torch.tensor)
a vector storing the ID list.
target : dict of data
self._data_store
Return
------
tensor
a tensor with the same row size of ID.
"""
return target[name][ID]
\ No newline at end of file
...@@ -327,3 +327,56 @@ def _clear_kv_msg(msg): ...@@ -327,3 +327,56 @@ def _clear_kv_msg(msg):
F.sync() F.sync()
if msg.c_ptr is not None: if msg.c_ptr is not None:
_CAPI_DeleteKVMsg(msg.c_ptr) _CAPI_DeleteKVMsg(msg.c_ptr)
def _fast_pull(name, id_tensor,
machine_count, group_count, machine_id, client_id,
partition_book, g2l, local_data,
sender, receiver):
""" Pull message
Parameters
----------
name : str
data name string
id_tensor : tensor
tensor of ID
machine_count : int
count of total machine
group_count : int
count of server group
machine_id : int
current machine id
client_id : int
current client ID
partition_book : tensor
tensor of partition book
g2l : tensor
tensor of global2local
local_data : tensor
tensor of local shared data
sender : ctypes.c_void_p
C Sender handle
receiver : ctypes.c_void_p
C Receiver handle
Return
------
tensor
target tensor
"""
if g2l is not None:
res_tensor = _CAPI_FastPull(name, machine_id, machine_count, group_count, client_id,
F.zerocopy_to_dgl_ndarray(id_tensor),
F.zerocopy_to_dgl_ndarray(partition_book),
F.zerocopy_to_dgl_ndarray(local_data),
sender, receiver, 'has_g2l',
F.zerocopy_to_dgl_ndarray(g2l))
else:
res_tensor = _CAPI_FastPull(name, machine_id, machine_count, group_count, client_id,
F.zerocopy_to_dgl_ndarray(id_tensor),
F.zerocopy_to_dgl_ndarray(partition_book),
F.zerocopy_to_dgl_ndarray(local_data),
sender, receiver, 'no_g2l')
return F.zerocopy_from_dgl_ndarray(res_tensor)
...@@ -5,6 +5,8 @@ ...@@ -5,6 +5,8 @@
*/ */
#include "./network.h" #include "./network.h"
#include <stdlib.h>
#include <dgl/runtime/container.h> #include <dgl/runtime/container.h>
#include <dgl/runtime/ndarray.h> #include <dgl/runtime/ndarray.h>
#include <dgl/packed_func_ext.h> #include <dgl/packed_func_ext.h>
...@@ -21,6 +23,8 @@ ...@@ -21,6 +23,8 @@
using dgl::network::StringPrintf; using dgl::network::StringPrintf;
using namespace dgl::runtime; using namespace dgl::runtime;
const bool AUTO_FREE = true;
namespace dgl { namespace dgl {
namespace network { namespace network {
...@@ -34,7 +38,8 @@ static void NaiveDeleter(DLManagedTensor* managed_tensor) { ...@@ -34,7 +38,8 @@ static void NaiveDeleter(DLManagedTensor* managed_tensor) {
NDArray CreateNDArrayFromRaw(std::vector<int64_t> shape, NDArray CreateNDArrayFromRaw(std::vector<int64_t> shape,
DLDataType dtype, DLDataType dtype,
DLContext ctx, DLContext ctx,
void* raw) { void* raw,
bool auto_free) {
DLTensor tensor; DLTensor tensor;
tensor.ctx = ctx; tensor.ctx = ctx;
tensor.ndim = static_cast<int>(shape.size()); tensor.ndim = static_cast<int>(shape.size());
...@@ -53,7 +58,9 @@ NDArray CreateNDArrayFromRaw(std::vector<int64_t> shape, ...@@ -53,7 +58,9 @@ NDArray CreateNDArrayFromRaw(std::vector<int64_t> shape,
tensor.data = raw; tensor.data = raw;
DLManagedTensor *managed_tensor = new DLManagedTensor(); DLManagedTensor *managed_tensor = new DLManagedTensor();
managed_tensor->dl_tensor = tensor; managed_tensor->dl_tensor = tensor;
managed_tensor->deleter = NaiveDeleter; if (auto_free) {
managed_tensor->deleter = NaiveDeleter;
}
return NDArray::FromDLPack(managed_tensor); return NDArray::FromDLPack(managed_tensor);
} }
...@@ -382,7 +389,8 @@ DGL_REGISTER_GLOBAL("network._CAPI_ReceiverRecvNodeFlow") ...@@ -382,7 +389,8 @@ DGL_REGISTER_GLOBAL("network._CAPI_ReceiverRecvNodeFlow")
{meta.data_shape_[1]}, {meta.data_shape_[1]},
DLDataType{kDLInt, 64, 1}, DLDataType{kDLInt, 64, 1},
DLContext{kDLCPU, 0}, DLContext{kDLCPU, 0},
array_0.data); array_0.data,
AUTO_FREE);
// edge_mapping // edge_mapping
Message array_1; Message array_1;
CHECK_EQ(receiver->RecvFrom(&array_1, send_id), REMOVE_SUCCESS); CHECK_EQ(receiver->RecvFrom(&array_1, send_id), REMOVE_SUCCESS);
...@@ -391,7 +399,8 @@ DGL_REGISTER_GLOBAL("network._CAPI_ReceiverRecvNodeFlow") ...@@ -391,7 +399,8 @@ DGL_REGISTER_GLOBAL("network._CAPI_ReceiverRecvNodeFlow")
{meta.data_shape_[3]}, {meta.data_shape_[3]},
DLDataType{kDLInt, 64, 1}, DLDataType{kDLInt, 64, 1},
DLContext{kDLCPU, 0}, DLContext{kDLCPU, 0},
array_1.data); array_1.data,
AUTO_FREE);
// layer_offset // layer_offset
Message array_2; Message array_2;
CHECK_EQ(receiver->RecvFrom(&array_2, send_id), REMOVE_SUCCESS); CHECK_EQ(receiver->RecvFrom(&array_2, send_id), REMOVE_SUCCESS);
...@@ -400,7 +409,8 @@ DGL_REGISTER_GLOBAL("network._CAPI_ReceiverRecvNodeFlow") ...@@ -400,7 +409,8 @@ DGL_REGISTER_GLOBAL("network._CAPI_ReceiverRecvNodeFlow")
{meta.data_shape_[5]}, {meta.data_shape_[5]},
DLDataType{kDLInt, 64, 1}, DLDataType{kDLInt, 64, 1},
DLContext{kDLCPU, 0}, DLContext{kDLCPU, 0},
array_2.data); array_2.data,
AUTO_FREE);
// flow_offset // flow_offset
Message array_3; Message array_3;
CHECK_EQ(receiver->RecvFrom(&array_3, send_id), REMOVE_SUCCESS); CHECK_EQ(receiver->RecvFrom(&array_3, send_id), REMOVE_SUCCESS);
...@@ -409,7 +419,8 @@ DGL_REGISTER_GLOBAL("network._CAPI_ReceiverRecvNodeFlow") ...@@ -409,7 +419,8 @@ DGL_REGISTER_GLOBAL("network._CAPI_ReceiverRecvNodeFlow")
{meta.data_shape_[7]}, {meta.data_shape_[7]},
DLDataType{kDLInt, 64, 1}, DLDataType{kDLInt, 64, 1},
DLContext{kDLCPU, 0}, DLContext{kDLCPU, 0},
array_3.data); array_3.data,
AUTO_FREE);
// CSR indptr // CSR indptr
Message array_4; Message array_4;
CHECK_EQ(receiver->RecvFrom(&array_4, send_id), REMOVE_SUCCESS); CHECK_EQ(receiver->RecvFrom(&array_4, send_id), REMOVE_SUCCESS);
...@@ -418,7 +429,8 @@ DGL_REGISTER_GLOBAL("network._CAPI_ReceiverRecvNodeFlow") ...@@ -418,7 +429,8 @@ DGL_REGISTER_GLOBAL("network._CAPI_ReceiverRecvNodeFlow")
{meta.data_shape_[9]}, {meta.data_shape_[9]},
DLDataType{kDLInt, 64, 1}, DLDataType{kDLInt, 64, 1},
DLContext{kDLCPU, 0}, DLContext{kDLCPU, 0},
array_4.data); array_4.data,
AUTO_FREE);
// CSR indice // CSR indice
Message array_5; Message array_5;
CHECK_EQ(receiver->RecvFrom(&array_5, send_id), REMOVE_SUCCESS); CHECK_EQ(receiver->RecvFrom(&array_5, send_id), REMOVE_SUCCESS);
...@@ -427,7 +439,8 @@ DGL_REGISTER_GLOBAL("network._CAPI_ReceiverRecvNodeFlow") ...@@ -427,7 +439,8 @@ DGL_REGISTER_GLOBAL("network._CAPI_ReceiverRecvNodeFlow")
{meta.data_shape_[11]}, {meta.data_shape_[11]},
DLDataType{kDLInt, 64, 1}, DLDataType{kDLInt, 64, 1},
DLContext{kDLCPU, 0}, DLContext{kDLCPU, 0},
array_5.data); array_5.data,
AUTO_FREE);
// CSR edge_ids // CSR edge_ids
Message array_6; Message array_6;
CHECK_EQ(receiver->RecvFrom(&array_6, send_id), REMOVE_SUCCESS); CHECK_EQ(receiver->RecvFrom(&array_6, send_id), REMOVE_SUCCESS);
...@@ -436,7 +449,8 @@ DGL_REGISTER_GLOBAL("network._CAPI_ReceiverRecvNodeFlow") ...@@ -436,7 +449,8 @@ DGL_REGISTER_GLOBAL("network._CAPI_ReceiverRecvNodeFlow")
{meta.data_shape_[13]}, {meta.data_shape_[13]},
DLDataType{kDLInt, 64, 1}, DLDataType{kDLInt, 64, 1},
DLContext{kDLCPU, 0}, DLContext{kDLCPU, 0},
array_6.data); array_6.data,
AUTO_FREE);
// Create CSR // Create CSR
CSRPtr csr(new CSR(indptr, indice, edge_ids)); CSRPtr csr(new CSR(indptr, indice, edge_ids));
nf->graph = GraphPtr(new ImmutableGraph(csr, nullptr)); nf->graph = GraphPtr(new ImmutableGraph(csr, nullptr));
...@@ -452,6 +466,106 @@ DGL_REGISTER_GLOBAL("network._CAPI_ReceiverRecvNodeFlow") ...@@ -452,6 +466,106 @@ DGL_REGISTER_GLOBAL("network._CAPI_ReceiverRecvNodeFlow")
////////////////////////// Distributed KVStore Components //////////////////////////////// ////////////////////////// Distributed KVStore Components ////////////////////////////////
static void send_kv_message(network::Sender* sender,
KVStoreMsg* kv_msg,
int recv_id,
bool auto_free) {
int64_t kv_size = 0;
char* kv_data = kv_msg->Serialize(&kv_size);
// Send kv_data
Message send_kv_msg;
send_kv_msg.data = kv_data;
send_kv_msg.size = kv_size;
if (auto_free) {
send_kv_msg.deallocator = DefaultMessageDeleter;
}
CHECK_EQ(sender->Send(send_kv_msg, recv_id), ADD_SUCCESS);
if (kv_msg->msg_type != kFinalMsg &&
kv_msg->msg_type != kBarrierMsg &&
kv_msg->msg_type != kIPIDMsg) {
// Send ArrayMeta
ArrayMeta meta(kv_msg->msg_type);
meta.AddArray(kv_msg->id);
if (kv_msg->msg_type != kPullMsg) {
meta.AddArray(kv_msg->data);
}
int64_t meta_size = 0;
char* meta_data = meta.Serialize(&meta_size);
Message send_meta_msg;
send_meta_msg.data = meta_data;
send_meta_msg.size = meta_size;
if (auto_free) {
send_meta_msg.deallocator = DefaultMessageDeleter;
}
CHECK_EQ(sender->Send(send_meta_msg, recv_id), ADD_SUCCESS);
// Send ID NDArray
Message send_id_msg;
send_id_msg.data = static_cast<char*>(kv_msg->id->data);
send_id_msg.size = kv_msg->id.GetSize();
NDArray id = kv_msg->id;
send_id_msg.deallocator = [id](Message*) {};
CHECK_EQ(sender->Send(send_id_msg, recv_id), ADD_SUCCESS);
// Send data NDArray
if (kv_msg->msg_type != kPullMsg) {
Message send_data_msg;
send_data_msg.data = static_cast<char*>(kv_msg->data->data);
send_data_msg.size = kv_msg->data.GetSize();
NDArray data = kv_msg->data;
if (auto_free) {
send_data_msg.deallocator = [data](Message*) {};
}
CHECK_EQ(sender->Send(send_data_msg, recv_id), ADD_SUCCESS);
}
}
}
static KVStoreMsg* recv_kv_message(network::Receiver* receiver) {
KVStoreMsg *kv_msg = new KVStoreMsg();
// Recv kv_Msg
Message recv_kv_msg;
int send_id;
CHECK_EQ(receiver->Recv(&recv_kv_msg, &send_id), REMOVE_SUCCESS);
kv_msg->Deserialize(recv_kv_msg.data, recv_kv_msg.size);
recv_kv_msg.deallocator(&recv_kv_msg);
if (kv_msg->msg_type == kFinalMsg ||
kv_msg->msg_type == kBarrierMsg ||
kv_msg->msg_type == kIPIDMsg) {
return kv_msg;
}
// Recv ArrayMeta
Message recv_meta_msg;
CHECK_EQ(receiver->RecvFrom(&recv_meta_msg, send_id), REMOVE_SUCCESS);
ArrayMeta meta(recv_meta_msg.data, recv_meta_msg.size);
recv_meta_msg.deallocator(&recv_meta_msg);
// Recv ID NDArray
Message recv_id_msg;
CHECK_EQ(receiver->RecvFrom(&recv_id_msg, send_id), REMOVE_SUCCESS);
CHECK_EQ(meta.data_shape_[0], 1);
kv_msg->id = CreateNDArrayFromRaw(
{meta.data_shape_[1]},
DLDataType{kDLInt, 64, 1},
DLContext{kDLCPU, 0},
recv_id_msg.data,
AUTO_FREE);
// Recv Data NDArray
if (kv_msg->msg_type != kPullMsg) {
Message recv_data_msg;
CHECK_EQ(receiver->RecvFrom(&recv_data_msg, send_id), REMOVE_SUCCESS);
CHECK_GE(meta.data_shape_[2], 1);
std::vector<int64_t> vec_shape;
for (int i = 3; i < meta.data_shape_.size(); ++i) {
vec_shape.push_back(meta.data_shape_[i]);
}
kv_msg->data = CreateNDArrayFromRaw(
vec_shape,
DLDataType{kDLFloat, 32, 1},
DLContext{kDLCPU, 0},
recv_data_msg.data,
AUTO_FREE);
}
return kv_msg;
}
DGL_REGISTER_GLOBAL("network._CAPI_SenderSendKVMsg") DGL_REGISTER_GLOBAL("network._CAPI_SenderSendKVMsg")
.set_body([] (DGLArgs args, DGLRetValue* rv) { .set_body([] (DGLArgs args, DGLRetValue* rv) {
int args_count = 0; int args_count = 0;
...@@ -471,97 +585,14 @@ DGL_REGISTER_GLOBAL("network._CAPI_SenderSendKVMsg") ...@@ -471,97 +585,14 @@ DGL_REGISTER_GLOBAL("network._CAPI_SenderSendKVMsg")
kv_msg.data = args[args_count++]; kv_msg.data = args[args_count++];
} }
} }
int64_t kv_size = 0; send_kv_message(sender, &kv_msg, recv_id, AUTO_FREE);
char* kv_data = kv_msg.Serialize(&kv_size);
// Send kv_data
Message send_kv_msg;
send_kv_msg.data = kv_data;
send_kv_msg.size = kv_size;
send_kv_msg.deallocator = DefaultMessageDeleter;
CHECK_EQ(sender->Send(send_kv_msg, recv_id), ADD_SUCCESS);
if (kv_msg.msg_type != kFinalMsg &&
kv_msg.msg_type != kBarrierMsg &&
kv_msg.msg_type != kIPIDMsg) {
// Send ArrayMeta
ArrayMeta meta(kv_msg.msg_type);
meta.AddArray(kv_msg.id);
if (kv_msg.msg_type != kPullMsg) {
meta.AddArray(kv_msg.data);
}
int64_t meta_size = 0;
char* meta_data = meta.Serialize(&meta_size);
Message send_meta_msg;
send_meta_msg.data = meta_data;
send_meta_msg.size = meta_size;
send_meta_msg.deallocator = DefaultMessageDeleter;
CHECK_EQ(sender->Send(send_meta_msg, recv_id), ADD_SUCCESS);
// Send ID NDArray
Message send_id_msg;
send_id_msg.data = static_cast<char*>(kv_msg.id->data);
send_id_msg.size = kv_msg.id.GetSize();
NDArray id = kv_msg.id;
send_id_msg.deallocator = [id](Message*) {};
CHECK_EQ(sender->Send(send_id_msg, recv_id), ADD_SUCCESS);
// Send data NDArray
if (kv_msg.msg_type != kPullMsg) {
Message send_data_msg;
send_data_msg.data = static_cast<char*>(kv_msg.data->data);
send_data_msg.size = kv_msg.data.GetSize();
NDArray data = kv_msg.data;
send_data_msg.deallocator = [data](Message*) {};
CHECK_EQ(sender->Send(send_data_msg, recv_id), ADD_SUCCESS);
}
}
}); });
DGL_REGISTER_GLOBAL("network.CAPI_ReceiverRecvKVMsg") DGL_REGISTER_GLOBAL("network.CAPI_ReceiverRecvKVMsg")
.set_body([] (DGLArgs args, DGLRetValue* rv) { .set_body([] (DGLArgs args, DGLRetValue* rv) {
CommunicatorHandle chandle = args[0]; CommunicatorHandle chandle = args[0];
network::Receiver* receiver = static_cast<network::SocketReceiver*>(chandle); network::Receiver* receiver = static_cast<network::SocketReceiver*>(chandle);
KVStoreMsg *kv_msg = new KVStoreMsg(); *rv = recv_kv_message(receiver);
// Recv kv_Msg
Message recv_kv_msg;
int send_id;
CHECK_EQ(receiver->Recv(&recv_kv_msg, &send_id), REMOVE_SUCCESS);
kv_msg->Deserialize(recv_kv_msg.data, recv_kv_msg.size);
recv_kv_msg.deallocator(&recv_kv_msg);
if (kv_msg->msg_type == kFinalMsg ||
kv_msg->msg_type == kBarrierMsg ||
kv_msg->msg_type == kIPIDMsg) {
*rv = kv_msg;
return;
}
// Recv ArrayMeta
Message recv_meta_msg;
CHECK_EQ(receiver->RecvFrom(&recv_meta_msg, send_id), REMOVE_SUCCESS);
ArrayMeta meta(recv_meta_msg.data, recv_meta_msg.size);
recv_meta_msg.deallocator(&recv_meta_msg);
// Recv ID NDArray
Message recv_id_msg;
CHECK_EQ(receiver->RecvFrom(&recv_id_msg, send_id), REMOVE_SUCCESS);
CHECK_EQ(meta.data_shape_[0], 1);
kv_msg->id = CreateNDArrayFromRaw(
{meta.data_shape_[1]},
DLDataType{kDLInt, 64, 1},
DLContext{kDLCPU, 0},
recv_id_msg.data);
// Recv Data NDArray
if (kv_msg->msg_type != kPullMsg) {
Message recv_data_msg;
CHECK_EQ(receiver->RecvFrom(&recv_data_msg, send_id), REMOVE_SUCCESS);
CHECK_GE(meta.data_shape_[2], 1);
std::vector<int64_t> vec_shape;
for (int i = 3; i < meta.data_shape_.size(); ++i) {
vec_shape.push_back(meta.data_shape_[i]);
}
kv_msg->data = CreateNDArrayFromRaw(
vec_shape,
DLDataType{kDLFloat, 32, 1},
DLContext{kDLCPU, 0},
recv_data_msg.data);
}
*rv = kv_msg;
}); });
DGL_REGISTER_GLOBAL("network._CAPI_ReceiverGetKVMsgType") DGL_REGISTER_GLOBAL("network._CAPI_ReceiverGetKVMsgType")
...@@ -606,5 +637,120 @@ DGL_REGISTER_GLOBAL("network._CAPI_DeleteKVMsg") ...@@ -606,5 +637,120 @@ DGL_REGISTER_GLOBAL("network._CAPI_DeleteKVMsg")
delete msg; delete msg;
}); });
DGL_REGISTER_GLOBAL("network._CAPI_FastPull")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
std::string name = args[0];
int local_machine_id = args[1];
int machine_count = args[2];
int group_count = args[3];
int client_id = args[4];
NDArray ID = args[5];
NDArray pb = args[6];
NDArray local_data = args[7];
CommunicatorHandle chandle_sender = args[8];
CommunicatorHandle chandle_receiver = args[9];
std::string str_flag = args[10];
network::Sender* sender = static_cast<network::Sender*>(chandle_sender);
network::Receiver* receiver = static_cast<network::SocketReceiver*>(chandle_receiver);
int64_t ID_size = ID.GetSize() / sizeof(int64_t);
int64_t* ID_data = static_cast<int64_t*>(ID->data);
int64_t* pb_data = static_cast<int64_t*>(pb->data);
char* local_data_char = static_cast<char*>(local_data->data);
std::vector<int64_t> local_ids;
std::vector<int64_t> local_ids_orginal;
std::vector<int64_t> local_data_shape;
std::vector<std::vector<int64_t> > remote_ids(machine_count);
std::vector<std::vector<int64_t> > remote_ids_original(machine_count);
unsigned int seed = 314;
int row_size = 1;
for (int i = 0; i < local_data->ndim; ++i) {
local_data_shape.push_back(local_data->shape[i]);
if (i != 0) {
row_size *= local_data->shape[i];
}
}
row_size *= sizeof(float);
// Get local id and remote id
if (str_flag.compare("has_g2l") == 0) {
NDArray g2l = args[11];
int64_t* g2l_data = static_cast<int64_t*>(g2l->data);
for (int64_t i = 0; i < ID_size; ++i) {
int64_t id = ID_data[i];
int64_t part_id = pb_data[id];
if (part_id == local_machine_id) {
int64_t local_id = g2l_data[id];
local_ids.push_back(local_id);
local_ids_orginal.push_back(i);
} else {
remote_ids[part_id].push_back(id);
remote_ids_original[part_id].push_back(i);
}
}
} else {
for (int64_t i = 0; i < ID_size; ++i) {
int64_t id = ID_data[i];
int64_t part_id = pb_data[id];
if (part_id == local_machine_id) {
local_ids.push_back(id);
local_ids_orginal.push_back(i);
} else {
remote_ids[part_id].push_back(id);
remote_ids_original[part_id].push_back(i);
}
}
}
int msg_count = 0;
for (int i = 0; i < remote_ids.size(); ++i) {
if (remote_ids[i].size() != 0) {
KVStoreMsg kv_msg;
kv_msg.msg_type = MessageType::kPullMsg;
kv_msg.rank = client_id;
kv_msg.name = name;
kv_msg.id = CreateNDArrayFromRaw({static_cast<int64_t>(remote_ids[i].size())},
DLDataType{kDLInt, 64, 1},
DLContext{kDLCPU, 0},
remote_ids[i].data(),
!AUTO_FREE);
int lower = i*group_count;
int higher = (i+1)*group_count-1;
#ifndef _WIN32 // windows does not support rand_r()
int s_id = (rand_r(&seed) % (higher-lower+1))+lower;
send_kv_message(sender, &kv_msg, s_id, !AUTO_FREE);
#endif
msg_count++;
}
}
char *return_data = new char[ID_size*row_size];
// Copy local data
#pragma omp parallel for
for (int64_t i = 0; i < local_ids.size(); ++i) {
memcpy(return_data + local_ids_orginal[i] * row_size,
local_data_char + local_ids[i] * row_size,
row_size);
}
// Recv remote message
for (int i = 0; i < msg_count; ++i) {
KVStoreMsg *kv_msg = recv_kv_message(receiver);
int64_t id_size = kv_msg->id.GetSize() / sizeof(int64_t);
int part_id = kv_msg->rank / group_count;
char* data_char = static_cast<char*>(kv_msg->data->data);
for (size_t n = 0; n < id_size; ++n) {
memcpy(return_data + remote_ids_original[part_id][n] * row_size,
data_char + n * row_size,
row_size);
}
delete kv_msg;
}
// Get final tensor
local_data_shape[0] = ID_size;
NDArray res_tensor = CreateNDArrayFromRaw(
local_data_shape,
DLDataType{kDLFloat, 32, 1},
DLContext{kDLCPU, 0},
return_data,
AUTO_FREE);
*rv = res_tensor;
});
} // namespace network } // namespace network
} // namespace dgl } // namespace dgl
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment