"git@developer.sourcefind.cn:OpenDAS/vision.git" did not exist on "1b9304c0d5e79d88fe7cd9bacc39adb2df1e0c46"
Unverified commit 338f24cf, authored by Chao Ma and committed via GitHub
Browse files

[KVstore] Fast-pull (#1446)

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* fix lint
parent ad222fb9
......@@ -5,6 +5,7 @@ from ..network import _network_wait, _add_receiver_addr
from ..network import _receiver_wait, _sender_connect
from ..network import _send_kv_msg, _recv_kv_msg
from ..network import _clear_kv_msg
from ..network import _fast_pull
from ..network import KVMsgType, KVStoreMsg
from .. import backend as F
......@@ -146,6 +147,11 @@ class KVServer(object):
self._open_file_list = []
# record for total message count
self._msg_count = 0
# user-defined push handler
self._udf_push_handler = None
self._udf_push_param = None
# user-defined pull handler
self._udf_pull_handler = None
def __del__(self):
......@@ -317,6 +323,8 @@ class KVServer(object):
# Get connected with all client nodes
_receiver_wait(self._receiver, self._ip, self._port, self._client_count)
print("%d clients connected!" % self._client_count)
# recv client address information
addr_list = []
for i in range(self._client_count):
......@@ -378,14 +386,20 @@ class KVServer(object):
local_id = self._data_store[msg.name+'-g2l-'][msg.id]
else:
local_id = msg.id
self._push_handler(msg.name+'-data-', local_id, msg.data, self._data_store)
if self._udf_push_handler is not None:
self._udf_push_handler(msg.name+'-data-', local_id, msg.data, self._data_store, self._udf_push_param)
else:
self._default_push_handler(msg.name+'-data-', local_id, msg.data, self._data_store)
# Pull message
elif msg.type == KVMsgType.PULL:
if (msg.name+'-g2l-' in self._has_data) == True:
local_id = self._data_store[msg.name+'-g2l-'][msg.id]
else:
local_id = msg.id
res_tensor = self._pull_handler(msg.name+'-data-', local_id, self._data_store)
if self._udf_pull_handler is not None:
res_tensor = self._udf_pull_handler(msg.name+'-data-', local_id, self._data_store)
else:
res_tensor = self._default_pull_handler(msg.name+'-data-', local_id, self._data_store)
back_msg = KVStoreMsg(
type=KVMsgType.PULL_BACK,
rank=self._server_id,
......@@ -500,7 +514,7 @@ class KVServer(object):
return data_shape
def _push_handler(self, name, ID, data, target):
def _default_push_handler(self, name, ID, data, target):
"""Default handler for PUSH message.
On default, _push_handler perform update operation for the tensor.
......@@ -519,7 +533,7 @@ class KVServer(object):
target[name][ID] = data
def _pull_handler(self, name, ID, target):
def _default_pull_handler(self, name, ID, target):
"""Default handler for PULL operation.
On default, _pull_handler perform get operation for the tensor.
......@@ -582,6 +596,7 @@ class KVClient(object):
self._server_namebook = server_namebook
self._server_count = len(server_namebook)
self._group_count = server_namebook[0][3]
self._machine_count = int(self._server_count / self._group_count)
# client ID will be assign by server after connecting to server
self._client_id = -1
# Get local machine id via server_namebook
......@@ -593,6 +608,11 @@ class KVClient(object):
self._open_file_list = []
# Gargage_collection
self._garbage_msg = []
# User-defined pull handler
self._udf_pull_handler = None
# User-defined push handler
self._udf_push_handler = None
self._udf_push_param = None
# Used load-balance
random.seed(time.time())
......@@ -812,7 +832,10 @@ class KVClient(object):
start += count[idx]
if local_id is not None: # local push
self._push_handler(name+'-data-', local_id, local_data, self._data_store)
if self._udf_push_handler is not None:
self._udf_push_handler(name+'-data-', local_id, local_data, self._data_store, self._udf_push_param)
else:
self._default_push_handler(name+'-data-', local_id, local_data, self._data_store)
def pull(self, name, id_tensor):
......@@ -833,73 +856,88 @@ class KVClient(object):
assert len(name) > 0, 'name cannot be empty.'
assert F.ndim(id_tensor) == 1, 'ID must be a vector.'
for msg in self._garbage_msg:
_clear_kv_msg(msg)
self._garbage_msg = []
# partition data
machine_id = self._data_store[name+'-part-'][id_tensor]
# sort index by machine id
sorted_id = F.tensor(np.argsort(F.asnumpy(machine_id)))
back_sorted_id = F.tensor(np.argsort(F.asnumpy(sorted_id)))
id_tensor = id_tensor[sorted_id]
machine, count = np.unique(F.asnumpy(machine_id), return_counts=True)
# pull data from server by order
start = 0
pull_count = 0
local_id = None
for idx in range(len(machine)):
end = start + count[idx]
if start == end: # No data for target machine
continue
partial_id = id_tensor[start:end]
if machine[idx] == self._machine_id: # local pull
# Note that DO NOT pull local data right now because we can overlap
# communication-local_pull here
if (name+'-g2l-' in self._has_data) == True:
local_id = self._data_store[name+'-g2l-'][partial_id]
else:
local_id = partial_id
else: # pull data from remote server
msg = KVStoreMsg(
type=KVMsgType.PULL,
rank=self._client_id,
if self._udf_pull_handler is None: # Use fast-pull
g2l = None
if name+'-g2l-' in self._data_store:
g2l = self._data_store[name+'-g2l-']
return _fast_pull(name, id_tensor,
self._machine_count,
self._group_count,
self._machine_id,
self._client_id,
self._data_store[name+'-part-'],
g2l,
self._data_store[name+'-data-'],
self._sender,
self._receiver)
else:
for msg in self._garbage_msg:
_clear_kv_msg(msg)
self._garbage_msg = []
# partition data
machine_id = self._data_store[name+'-part-'][id_tensor]
# sort index by machine id
sorted_id = F.tensor(np.argsort(F.asnumpy(machine_id)))
back_sorted_id = F.tensor(np.argsort(F.asnumpy(sorted_id)))
id_tensor = id_tensor[sorted_id]
machine, count = np.unique(F.asnumpy(machine_id), return_counts=True)
# pull data from server by order
start = 0
pull_count = 0
local_id = None
for idx in range(len(machine)):
end = start + count[idx]
if start == end: # No data for target machine
continue
partial_id = id_tensor[start:end]
if machine[idx] == self._machine_id: # local pull
# Note that DO NOT pull local data right now because we can overlap
# communication-local_pull here
if (name+'-g2l-' in self._has_data) == True:
local_id = self._data_store[name+'-g2l-'][partial_id]
else:
local_id = partial_id
else: # pull data from remote server
msg = KVStoreMsg(
type=KVMsgType.PULL,
rank=self._client_id,
name=name,
id=partial_id,
data=None,
c_ptr=None)
# randomly select a server node in target machine for load-balance
s_id = random.randint(machine[idx]*self._group_count, (machine[idx]+1)*self._group_count-1)
_send_kv_msg(self._sender, msg, s_id)
pull_count += 1
start += count[idx]
msg_list = []
if local_id is not None: # local pull
local_data = self._udf_pull_handler(name+'-data-', local_id, self._data_store)
s_id = random.randint(self._machine_id*self._group_count, (self._machine_id+1)*self._group_count-1)
local_msg = KVStoreMsg(
type=KVMsgType.PULL_BACK,
rank=s_id,
name=name,
id=partial_id,
data=None,
id=None,
data=local_data,
c_ptr=None)
# randomly select a server node in target machine for load-balance
s_id = random.randint(machine[idx]*self._group_count, (machine[idx]+1)*self._group_count-1)
_send_kv_msg(self._sender, msg, s_id)
pull_count += 1
start += count[idx]
msg_list = []
if local_id is not None: # local pull
local_data = self._pull_handler(name+'-data-', local_id, self._data_store)
s_id = random.randint(self._machine_id*self._group_count, (self._machine_id+1)*self._group_count-1)
local_msg = KVStoreMsg(
type=KVMsgType.PULL_BACK,
rank=s_id,
name=name,
id=None,
data=local_data,
c_ptr=None)
msg_list.append(local_msg)
self._garbage_msg.append(local_msg)
msg_list.append(local_msg)
self._garbage_msg.append(local_msg)
# wait message from server nodes
for idx in range(pull_count):
remote_msg = _recv_kv_msg(self._receiver)
msg_list.append(remote_msg)
self._garbage_msg.append(remote_msg)
# wait message from server nodes
for idx in range(pull_count):
remote_msg = _recv_kv_msg(self._receiver)
msg_list.append(remote_msg)
self._garbage_msg.append(remote_msg)
# sort msg by server id and merge tensor together
msg_list.sort(key=self._takeId)
data_tensor = F.cat(seq=[msg.data for msg in msg_list], dim=0)
# sort msg by server id and merge tensor together
msg_list.sort(key=self._takeId)
data_tensor = F.cat(seq=[msg.data for msg in msg_list], dim=0)
return data_tensor[back_sorted_id] # return data with original index order
return data_tensor[back_sorted_id] # return data with original index order
def barrier(self):
......@@ -1082,7 +1120,7 @@ class KVClient(object):
return elem.rank
def _push_handler(self, name, ID, data, target):
def _default_push_handler(self, name, ID, data, target):
"""Default handler for PUSH message.
On default, _push_handler perform update operation for the tensor.
......@@ -1099,26 +1137,4 @@ class KVClient(object):
self._data_store
"""
target[name][ID] = data
def _pull_handler(self, name, ID, target):
"""Default handler for PULL operation.
On default, _pull_handler perform get operation for the tensor.
Parameters
----------
name : str
data name
ID : tensor (mx.ndarray or torch.tensor)
a vector storing the ID list.
target : dict of data
self._data_store
Return
------
tensor
a tensor with the same row size of ID.
"""
return target[name][ID]
\ No newline at end of file
......@@ -327,3 +327,56 @@ def _clear_kv_msg(msg):
F.sync()
if msg.c_ptr is not None:
_CAPI_DeleteKVMsg(msg.c_ptr)
def _fast_pull(name, id_tensor,
               machine_count, group_count, machine_id, client_id,
               partition_book, g2l, local_data,
               sender, receiver):
    """Pull rows of a distributed tensor via the C fast-pull path.

    Local rows are copied from shared memory and remote rows are
    fetched from the remote servers inside a single C API call.

    Parameters
    ----------
    name : str
        data name string
    id_tensor : tensor
        tensor of ID
    machine_count : int
        count of total machine
    group_count : int
        count of server group
    machine_id : int
        current machine id
    client_id : int
        current client ID
    partition_book : tensor
        tensor of partition book
    g2l : tensor or None
        tensor of global2local mapping (None when the data has no
        global-to-local mapping)
    local_data : tensor
        tensor of local shared data
    sender : ctypes.c_void_p
        C Sender handle
    receiver : ctypes.c_void_p
        C Receiver handle

    Returns
    -------
    tensor
        one row per requested ID, in the order of ``id_tensor``
    """
    # Build the positional argument list once; only the trailing
    # flag (and the optional g2l array) differ between the two modes.
    capi_args = [name, machine_id, machine_count, group_count, client_id,
                 F.zerocopy_to_dgl_ndarray(id_tensor),
                 F.zerocopy_to_dgl_ndarray(partition_book),
                 F.zerocopy_to_dgl_ndarray(local_data),
                 sender, receiver]
    if g2l is None:
        capi_args.append('no_g2l')
    else:
        capi_args.append('has_g2l')
        capi_args.append(F.zerocopy_to_dgl_ndarray(g2l))
    res_tensor = _CAPI_FastPull(*capi_args)
    return F.zerocopy_from_dgl_ndarray(res_tensor)
......@@ -5,6 +5,8 @@
*/
#include "./network.h"
#include <stdlib.h>
#include <dgl/runtime/container.h>
#include <dgl/runtime/ndarray.h>
#include <dgl/packed_func_ext.h>
......@@ -21,6 +23,8 @@
using dgl::network::StringPrintf;
using namespace dgl::runtime;
const bool AUTO_FREE = true;
namespace dgl {
namespace network {
......@@ -34,7 +38,8 @@ static void NaiveDeleter(DLManagedTensor* managed_tensor) {
NDArray CreateNDArrayFromRaw(std::vector<int64_t> shape,
DLDataType dtype,
DLContext ctx,
void* raw) {
void* raw,
bool auto_free) {
DLTensor tensor;
tensor.ctx = ctx;
tensor.ndim = static_cast<int>(shape.size());
......@@ -53,7 +58,9 @@ NDArray CreateNDArrayFromRaw(std::vector<int64_t> shape,
tensor.data = raw;
DLManagedTensor *managed_tensor = new DLManagedTensor();
managed_tensor->dl_tensor = tensor;
managed_tensor->deleter = NaiveDeleter;
if (auto_free) {
managed_tensor->deleter = NaiveDeleter;
}
return NDArray::FromDLPack(managed_tensor);
}
......@@ -382,7 +389,8 @@ DGL_REGISTER_GLOBAL("network._CAPI_ReceiverRecvNodeFlow")
{meta.data_shape_[1]},
DLDataType{kDLInt, 64, 1},
DLContext{kDLCPU, 0},
array_0.data);
array_0.data,
AUTO_FREE);
// edge_mapping
Message array_1;
CHECK_EQ(receiver->RecvFrom(&array_1, send_id), REMOVE_SUCCESS);
......@@ -391,7 +399,8 @@ DGL_REGISTER_GLOBAL("network._CAPI_ReceiverRecvNodeFlow")
{meta.data_shape_[3]},
DLDataType{kDLInt, 64, 1},
DLContext{kDLCPU, 0},
array_1.data);
array_1.data,
AUTO_FREE);
// layer_offset
Message array_2;
CHECK_EQ(receiver->RecvFrom(&array_2, send_id), REMOVE_SUCCESS);
......@@ -400,7 +409,8 @@ DGL_REGISTER_GLOBAL("network._CAPI_ReceiverRecvNodeFlow")
{meta.data_shape_[5]},
DLDataType{kDLInt, 64, 1},
DLContext{kDLCPU, 0},
array_2.data);
array_2.data,
AUTO_FREE);
// flow_offset
Message array_3;
CHECK_EQ(receiver->RecvFrom(&array_3, send_id), REMOVE_SUCCESS);
......@@ -409,7 +419,8 @@ DGL_REGISTER_GLOBAL("network._CAPI_ReceiverRecvNodeFlow")
{meta.data_shape_[7]},
DLDataType{kDLInt, 64, 1},
DLContext{kDLCPU, 0},
array_3.data);
array_3.data,
AUTO_FREE);
// CSR indptr
Message array_4;
CHECK_EQ(receiver->RecvFrom(&array_4, send_id), REMOVE_SUCCESS);
......@@ -418,7 +429,8 @@ DGL_REGISTER_GLOBAL("network._CAPI_ReceiverRecvNodeFlow")
{meta.data_shape_[9]},
DLDataType{kDLInt, 64, 1},
DLContext{kDLCPU, 0},
array_4.data);
array_4.data,
AUTO_FREE);
// CSR indice
Message array_5;
CHECK_EQ(receiver->RecvFrom(&array_5, send_id), REMOVE_SUCCESS);
......@@ -427,7 +439,8 @@ DGL_REGISTER_GLOBAL("network._CAPI_ReceiverRecvNodeFlow")
{meta.data_shape_[11]},
DLDataType{kDLInt, 64, 1},
DLContext{kDLCPU, 0},
array_5.data);
array_5.data,
AUTO_FREE);
// CSR edge_ids
Message array_6;
CHECK_EQ(receiver->RecvFrom(&array_6, send_id), REMOVE_SUCCESS);
......@@ -436,7 +449,8 @@ DGL_REGISTER_GLOBAL("network._CAPI_ReceiverRecvNodeFlow")
{meta.data_shape_[13]},
DLDataType{kDLInt, 64, 1},
DLContext{kDLCPU, 0},
array_6.data);
array_6.data,
AUTO_FREE);
// Create CSR
CSRPtr csr(new CSR(indptr, indice, edge_ids));
nf->graph = GraphPtr(new ImmutableGraph(csr, nullptr));
......@@ -452,6 +466,106 @@ DGL_REGISTER_GLOBAL("network._CAPI_ReceiverRecvNodeFlow")
////////////////////////// Distributed KVStore Components ////////////////////////////////
// Serialize *kv_msg* and send it to worker `recv_id` through *sender*.
//
// The payload goes out as up to four wire messages: (1) the serialized
// KVStoreMsg header; for non-control messages also (2) an ArrayMeta,
// (3) the raw ID NDArray buffer, and (4) for non-PULL messages the raw
// data NDArray buffer.
//
// `auto_free` controls whether the heap buffers produced by Serialize()
// are handed to the messaging layer for deletion (DefaultMessageDeleter).
// The fast-pull path passes false because its ID array wraps
// std::vector storage owned by the caller (see _CAPI_FastPull).
static void send_kv_message(network::Sender* sender,
                            KVStoreMsg* kv_msg,
                            int recv_id,
                            bool auto_free) {
  int64_t kv_size = 0;
  char* kv_data = kv_msg->Serialize(&kv_size);
  // Send kv_data (the serialized KVStoreMsg header)
  Message send_kv_msg;
  send_kv_msg.data = kv_data;
  send_kv_msg.size = kv_size;
  if (auto_free) {
    send_kv_msg.deallocator = DefaultMessageDeleter;
  }
  CHECK_EQ(sender->Send(send_kv_msg, recv_id), ADD_SUCCESS);
  // Control messages (final / barrier / IP-ID) carry no arrays.
  if (kv_msg->msg_type != kFinalMsg &&
      kv_msg->msg_type != kBarrierMsg &&
      kv_msg->msg_type != kIPIDMsg) {
    // Send ArrayMeta describing the ID (and, if present, data) arrays
    ArrayMeta meta(kv_msg->msg_type);
    meta.AddArray(kv_msg->id);
    if (kv_msg->msg_type != kPullMsg) {
      meta.AddArray(kv_msg->data);
    }
    int64_t meta_size = 0;
    char* meta_data = meta.Serialize(&meta_size);
    Message send_meta_msg;
    send_meta_msg.data = meta_data;
    send_meta_msg.size = meta_size;
    if (auto_free) {
      send_meta_msg.deallocator = DefaultMessageDeleter;
    }
    CHECK_EQ(sender->Send(send_meta_msg, recv_id), ADD_SUCCESS);
    // Send ID NDArray. The no-op deallocator captures the NDArray by
    // value, keeping it alive until the messaging layer is done with
    // the buffer (the NDArray itself owns/borrows the memory).
    Message send_id_msg;
    send_id_msg.data = static_cast<char*>(kv_msg->id->data);
    send_id_msg.size = kv_msg->id.GetSize();
    NDArray id = kv_msg->id;
    send_id_msg.deallocator = [id](Message*) {};
    CHECK_EQ(sender->Send(send_id_msg, recv_id), ADD_SUCCESS);
    // Send data NDArray (absent for PULL requests)
    if (kv_msg->msg_type != kPullMsg) {
      Message send_data_msg;
      send_data_msg.data = static_cast<char*>(kv_msg->data->data);
      send_data_msg.size = kv_msg->data.GetSize();
      NDArray data = kv_msg->data;
      // NOTE(review): unlike the ID message above, the keep-alive
      // deallocator here is only installed when auto_free is set —
      // presumably the transport tolerates a null deallocator and the
      // !auto_free caller guarantees the buffer outlives the send;
      // confirm against the Sender implementation.
      if (auto_free) {
        send_data_msg.deallocator = [data](Message*) {};
      }
      CHECK_EQ(sender->Send(send_data_msg, recv_id), ADD_SUCCESS);
    }
  }
}
// Receive one KVStoreMsg from *receiver* (the inverse of send_kv_message).
//
// Reads the serialized header first; for control messages (final /
// barrier / IP-ID) that is the whole payload. Otherwise it then reads
// the ArrayMeta, the ID NDArray (int64) and — for non-PULL messages —
// the data NDArray (assumed float32, per the hard-coded dtype below).
// All follow-up reads use RecvFrom(send_id) so the arrays come from the
// same peer as the header.
//
// Returns a heap-allocated KVStoreMsg; the caller owns it and must
// delete it (callers rely on _CAPI_DeleteKVMsg / explicit delete).
static KVStoreMsg* recv_kv_message(network::Receiver* receiver) {
  KVStoreMsg *kv_msg = new KVStoreMsg();
  // Recv the serialized KVStoreMsg header
  Message recv_kv_msg;
  int send_id;
  CHECK_EQ(receiver->Recv(&recv_kv_msg, &send_id), REMOVE_SUCCESS);
  kv_msg->Deserialize(recv_kv_msg.data, recv_kv_msg.size);
  recv_kv_msg.deallocator(&recv_kv_msg);
  // Control messages carry no payload arrays
  if (kv_msg->msg_type == kFinalMsg ||
      kv_msg->msg_type == kBarrierMsg ||
      kv_msg->msg_type == kIPIDMsg) {
    return kv_msg;
  }
  // Recv ArrayMeta
  Message recv_meta_msg;
  CHECK_EQ(receiver->RecvFrom(&recv_meta_msg, send_id), REMOVE_SUCCESS);
  ArrayMeta meta(recv_meta_msg.data, recv_meta_msg.size);
  recv_meta_msg.deallocator(&recv_meta_msg);
  // Recv ID NDArray (meta slot 0 holds its ndim, slot 1 its length)
  Message recv_id_msg;
  CHECK_EQ(receiver->RecvFrom(&recv_id_msg, send_id), REMOVE_SUCCESS);
  CHECK_EQ(meta.data_shape_[0], 1);
  kv_msg->id = CreateNDArrayFromRaw(
    {meta.data_shape_[1]},
    DLDataType{kDLInt, 64, 1},
    DLContext{kDLCPU, 0},
    recv_id_msg.data,
    AUTO_FREE);
  // Recv Data NDArray (absent for PULL requests)
  if (kv_msg->msg_type != kPullMsg) {
    Message recv_data_msg;
    CHECK_EQ(receiver->RecvFrom(&recv_data_msg, send_id), REMOVE_SUCCESS);
    CHECK_GE(meta.data_shape_[2], 1);
    std::vector<int64_t> vec_shape;
    // Fix: use size_t for the index — data_shape_.size() is unsigned, the
    // old `int i` triggered a signed/unsigned comparison.
    for (size_t i = 3; i < meta.data_shape_.size(); ++i) {
      vec_shape.push_back(meta.data_shape_[i]);
    }
    kv_msg->data = CreateNDArrayFromRaw(
      vec_shape,
      DLDataType{kDLFloat, 32, 1},
      DLContext{kDLCPU, 0},
      recv_data_msg.data,
      AUTO_FREE);
  }
  return kv_msg;
}
DGL_REGISTER_GLOBAL("network._CAPI_SenderSendKVMsg")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
int args_count = 0;
......@@ -471,97 +585,14 @@ DGL_REGISTER_GLOBAL("network._CAPI_SenderSendKVMsg")
kv_msg.data = args[args_count++];
}
}
int64_t kv_size = 0;
char* kv_data = kv_msg.Serialize(&kv_size);
// Send kv_data
Message send_kv_msg;
send_kv_msg.data = kv_data;
send_kv_msg.size = kv_size;
send_kv_msg.deallocator = DefaultMessageDeleter;
CHECK_EQ(sender->Send(send_kv_msg, recv_id), ADD_SUCCESS);
if (kv_msg.msg_type != kFinalMsg &&
kv_msg.msg_type != kBarrierMsg &&
kv_msg.msg_type != kIPIDMsg) {
// Send ArrayMeta
ArrayMeta meta(kv_msg.msg_type);
meta.AddArray(kv_msg.id);
if (kv_msg.msg_type != kPullMsg) {
meta.AddArray(kv_msg.data);
}
int64_t meta_size = 0;
char* meta_data = meta.Serialize(&meta_size);
Message send_meta_msg;
send_meta_msg.data = meta_data;
send_meta_msg.size = meta_size;
send_meta_msg.deallocator = DefaultMessageDeleter;
CHECK_EQ(sender->Send(send_meta_msg, recv_id), ADD_SUCCESS);
// Send ID NDArray
Message send_id_msg;
send_id_msg.data = static_cast<char*>(kv_msg.id->data);
send_id_msg.size = kv_msg.id.GetSize();
NDArray id = kv_msg.id;
send_id_msg.deallocator = [id](Message*) {};
CHECK_EQ(sender->Send(send_id_msg, recv_id), ADD_SUCCESS);
// Send data NDArray
if (kv_msg.msg_type != kPullMsg) {
Message send_data_msg;
send_data_msg.data = static_cast<char*>(kv_msg.data->data);
send_data_msg.size = kv_msg.data.GetSize();
NDArray data = kv_msg.data;
send_data_msg.deallocator = [data](Message*) {};
CHECK_EQ(sender->Send(send_data_msg, recv_id), ADD_SUCCESS);
}
}
send_kv_message(sender, &kv_msg, recv_id, AUTO_FREE);
});
DGL_REGISTER_GLOBAL("network.CAPI_ReceiverRecvKVMsg")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
CommunicatorHandle chandle = args[0];
network::Receiver* receiver = static_cast<network::SocketReceiver*>(chandle);
KVStoreMsg *kv_msg = new KVStoreMsg();
// Recv kv_Msg
Message recv_kv_msg;
int send_id;
CHECK_EQ(receiver->Recv(&recv_kv_msg, &send_id), REMOVE_SUCCESS);
kv_msg->Deserialize(recv_kv_msg.data, recv_kv_msg.size);
recv_kv_msg.deallocator(&recv_kv_msg);
if (kv_msg->msg_type == kFinalMsg ||
kv_msg->msg_type == kBarrierMsg ||
kv_msg->msg_type == kIPIDMsg) {
*rv = kv_msg;
return;
}
// Recv ArrayMeta
Message recv_meta_msg;
CHECK_EQ(receiver->RecvFrom(&recv_meta_msg, send_id), REMOVE_SUCCESS);
ArrayMeta meta(recv_meta_msg.data, recv_meta_msg.size);
recv_meta_msg.deallocator(&recv_meta_msg);
// Recv ID NDArray
Message recv_id_msg;
CHECK_EQ(receiver->RecvFrom(&recv_id_msg, send_id), REMOVE_SUCCESS);
CHECK_EQ(meta.data_shape_[0], 1);
kv_msg->id = CreateNDArrayFromRaw(
{meta.data_shape_[1]},
DLDataType{kDLInt, 64, 1},
DLContext{kDLCPU, 0},
recv_id_msg.data);
// Recv Data NDArray
if (kv_msg->msg_type != kPullMsg) {
Message recv_data_msg;
CHECK_EQ(receiver->RecvFrom(&recv_data_msg, send_id), REMOVE_SUCCESS);
CHECK_GE(meta.data_shape_[2], 1);
std::vector<int64_t> vec_shape;
for (int i = 3; i < meta.data_shape_.size(); ++i) {
vec_shape.push_back(meta.data_shape_[i]);
}
kv_msg->data = CreateNDArrayFromRaw(
vec_shape,
DLDataType{kDLFloat, 32, 1},
DLContext{kDLCPU, 0},
recv_data_msg.data);
}
*rv = kv_msg;
*rv = recv_kv_message(receiver);
});
DGL_REGISTER_GLOBAL("network._CAPI_ReceiverGetKVMsgType")
......@@ -606,5 +637,120 @@ DGL_REGISTER_GLOBAL("network._CAPI_DeleteKVMsg")
delete msg;
});
// Fast-pull: resolve a batch of global IDs against the partition book,
// copy locally-owned rows straight out of the shared-memory tensor, and
// fetch the rest with one PULL request per remote machine. Produces a
// tensor with one row per requested ID, in the original request order.
DGL_REGISTER_GLOBAL("network._CAPI_FastPull")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
  std::string name = args[0];
  int local_machine_id = args[1];
  int machine_count = args[2];
  int group_count = args[3];    // servers per machine
  int client_id = args[4];
  NDArray ID = args[5];         // int64 vector of global IDs to pull
  NDArray pb = args[6];         // partition book: global ID -> machine id
  NDArray local_data = args[7]; // local shard of the data tensor
  CommunicatorHandle chandle_sender = args[8];
  CommunicatorHandle chandle_receiver = args[9];
  std::string str_flag = args[10];  // 'has_g2l' or 'no_g2l'
  network::Sender* sender = static_cast<network::Sender*>(chandle_sender);
  network::Receiver* receiver = static_cast<network::SocketReceiver*>(chandle_receiver);
  int64_t ID_size = ID.GetSize() / sizeof(int64_t);
  int64_t* ID_data = static_cast<int64_t*>(ID->data);
  int64_t* pb_data = static_cast<int64_t*>(pb->data);
  char* local_data_char = static_cast<char*>(local_data->data);
  // Local rows: their local row index and their position in the request.
  std::vector<int64_t> local_ids;
  std::vector<int64_t> local_ids_orginal;
  std::vector<int64_t> local_data_shape;
  // Remote rows, bucketed per target machine (same two-way split).
  std::vector<std::vector<int64_t> > remote_ids(machine_count);
  std::vector<std::vector<int64_t> > remote_ids_original(machine_count);
  // Fixed seed -> deterministic (per-process) server choice for load-balance.
  unsigned int seed = 314;
  // row_size becomes the byte size of one row; dim 0 is the row count.
  // NOTE(review): assumes the data dtype is float32 (matches the
  // kDLFloat/32 result dtype below) — confirm for all stored tensors.
  int row_size = 1;
  for (int i = 0; i < local_data->ndim; ++i) {
    local_data_shape.push_back(local_data->shape[i]);
    if (i != 0) {
      row_size *= local_data->shape[i];
    }
  }
  row_size *= sizeof(float);
  // Get local id and remote id
  if (str_flag.compare("has_g2l") == 0) {
    // With a global-to-local mapping: translate local hits through g2l.
    NDArray g2l = args[11];
    int64_t* g2l_data = static_cast<int64_t*>(g2l->data);
    for (int64_t i = 0; i < ID_size; ++i) {
      int64_t id = ID_data[i];
      int64_t part_id = pb_data[id];
      if (part_id == local_machine_id) {
        int64_t local_id = g2l_data[id];
        local_ids.push_back(local_id);
        local_ids_orginal.push_back(i);
      } else {
        remote_ids[part_id].push_back(id);
        remote_ids_original[part_id].push_back(i);
      }
    }
  } else {
    // Without g2l: the global ID indexes the local shard directly.
    for (int64_t i = 0; i < ID_size; ++i) {
      int64_t id = ID_data[i];
      int64_t part_id = pb_data[id];
      if (part_id == local_machine_id) {
        local_ids.push_back(id);
        local_ids_orginal.push_back(i);
      } else {
        remote_ids[part_id].push_back(id);
        remote_ids_original[part_id].push_back(i);
      }
    }
  }
  // Send one PULL request per machine that owns any requested row.
  int msg_count = 0;
  for (int i = 0; i < remote_ids.size(); ++i) {
    if (remote_ids[i].size() != 0) {
      KVStoreMsg kv_msg;
      kv_msg.msg_type = MessageType::kPullMsg;
      kv_msg.rank = client_id;
      kv_msg.name = name;
      // !AUTO_FREE: the NDArray wraps the std::vector's storage, which
      // stays owned by remote_ids[i]; it must not be freed by the
      // NDArray deleter or the messaging layer.
      kv_msg.id = CreateNDArrayFromRaw({static_cast<int64_t>(remote_ids[i].size())},
                                        DLDataType{kDLInt, 64, 1},
                                        DLContext{kDLCPU, 0},
                                        remote_ids[i].data(),
                                        !AUTO_FREE);
      // Pick a random server inside the target machine's group.
      int lower = i*group_count;
      int higher = (i+1)*group_count-1;
#ifndef _WIN32  // windows does not support rand_r()
      int s_id = (rand_r(&seed) % (higher-lower+1))+lower;
      send_kv_message(sender, &kv_msg, s_id, !AUTO_FREE);
#endif
      // NOTE(review): msg_count is incremented even on Windows where the
      // send above is compiled out — the receive loop below would then
      // wait for replies that were never requested. Presumably fast-pull
      // is non-Windows only; confirm.
      msg_count++;
    }
  }
  // Result buffer; ownership passes to the NDArray created at the end
  // (AUTO_FREE -> freed by its deleter).
  char *return_data = new char[ID_size*row_size];
  // Copy local data while the remote requests are in flight.
  #pragma omp parallel for
  for (int64_t i = 0; i < local_ids.size(); ++i) {
    memcpy(return_data + local_ids_orginal[i] * row_size,
           local_data_char + local_ids[i] * row_size,
           row_size);
  }
  // Recv remote message: one reply per request sent above.
  for (int i = 0; i < msg_count; ++i) {
    KVStoreMsg *kv_msg = recv_kv_message(receiver);
    int64_t id_size = kv_msg->id.GetSize() / sizeof(int64_t);
    // Map the replying server back to its machine bucket.
    int part_id = kv_msg->rank / group_count;
    char* data_char = static_cast<char*>(kv_msg->data->data);
    // Rows in the reply are assumed to be in the same order as the IDs
    // in the request (the server pulls msg.id in order); no length check
    // of id_size against remote_ids_original[part_id].size() is done.
    for (size_t n = 0; n < id_size; ++n) {
      memcpy(return_data + remote_ids_original[part_id][n] * row_size,
             data_char + n * row_size,
             row_size);
    }
    delete kv_msg;
  }
  // Get final tensor: same trailing dims as local_data, ID_size rows.
  local_data_shape[0] = ID_size;
  NDArray res_tensor = CreateNDArrayFromRaw(
    local_data_shape,
    DLDataType{kDLFloat, 32, 1},
    DLContext{kDLCPU, 0},
    return_data,
    AUTO_FREE);
  *rv = res_tensor;
});
} // namespace network
} // namespace dgl
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment