"...git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "c8bb1ff53ee1210e04861abd41518f197cfd8b3c"
Unverified Commit c3516f1a authored by Chao Ma's avatar Chao Ma Committed by GitHub
Browse files

[Network] Refactoring Communicator (#679)

* Refactoring Communicator

* fix lint

* change non-const reference

* add header file

* use MemoryBuffer

* update PR

* fix bug on csr shape

* zero-copy msg_queue

* fix lint

* fix lint

* fix lint

* add header file

* fix windows build error

* fix windows build error

* update

* fix lint

* update

* fix lint

* fix lint

* add more test

* fix windows test

* update windows test

* update windows test

* update windows test

* update

* fix lint

* fix lint

* update

* update

* update

* update

* use STATUS code

* update test

* remove mem_cpy

* fix lint

* update

* finish

* ConstructNFTensor

* add test for deallocator

* update

* fix lint
parent fc9d30fa
...@@ -3,20 +3,20 @@ from ...network import _send_nodeflow, _recv_nodeflow ...@@ -3,20 +3,20 @@ from ...network import _send_nodeflow, _recv_nodeflow
from ...network import _create_sender, _create_receiver from ...network import _create_sender, _create_receiver
from ...network import _finalize_sender, _finalize_receiver from ...network import _finalize_sender, _finalize_receiver
from ...network import _add_receiver_addr, _sender_connect from ...network import _add_receiver_addr, _sender_connect
from ...network import _receiver_wait, _send_end_signal from ...network import _receiver_wait, _send_sampler_end_signal
from multiprocessing import Pool from multiprocessing import Pool
from abc import ABCMeta, abstractmethod from abc import ABCMeta, abstractmethod
class SamplerPool(object): class SamplerPool(object):
"""SamplerPool is an abstract class, in which the worker method """SamplerPool is an abstract class, in which the worker() method
should be implemented by users. SamplerPool will fork() N (N = num_worker) should be implemented by users. SamplerPool will fork() N (N = num_worker)
child processes, and each process will perform worker() method independently. child processes, and each process will perform worker() method independently.
Note that, the fork() API will use shared memory for N process and the OS will Note that, the fork() API uses shared memory for N processes and the OS will
perfrom copy-on-write only when developers write that piece of memory. So fork N perfrom copy-on-write on that only when developers write that piece of memory.
processes and load N copy of graph will not increase the memory overhead. So fork N processes and load N copies of graph will not increase the memory overhead.
Users can use this class like this: For example, users can use this class like this:
class MySamplerPool(SamplerPool): class MySamplerPool(SamplerPool):
...@@ -37,13 +37,13 @@ class SamplerPool(object): ...@@ -37,13 +37,13 @@ class SamplerPool(object):
Parameters Parameters
---------- ----------
num_worker : int num_worker : int
number of worker (number of child process) number of child process
args : arguments args : arguments
any arguments passed by user any arguments passed by user
""" """
p = Pool() p = Pool()
for i in range(num_worker): for i in range(num_worker):
print("Start child process %d ..." % i) print("Start child sampler process %d ..." % i)
p.apply_async(self.worker, args=(args,)) p.apply_async(self.worker, args=(args,))
# Waiting for all subprocesses done ... # Waiting for all subprocesses done ...
p.close() p.close()
...@@ -51,7 +51,7 @@ class SamplerPool(object): ...@@ -51,7 +51,7 @@ class SamplerPool(object):
@abstractmethod @abstractmethod
def worker(self, args): def worker(self, args):
"""User-defined function """User-defined function for worker
Parameters Parameters
---------- ----------
...@@ -63,28 +63,34 @@ class SamplerPool(object): ...@@ -63,28 +63,34 @@ class SamplerPool(object):
class SamplerSender(object): class SamplerSender(object):
"""SamplerSender for DGL distributed training. """SamplerSender for DGL distributed training.
Users use SamplerSender to send sampled subgraph (NodeFlow) Users use SamplerSender to send sampled subgraphs (NodeFlow)
to remote SamplerReceiver. Note that a SamplerSender can connect to remote SamplerReceiver. Note that, a SamplerSender can connect
to multiple SamplerReceiver. to multiple SamplerReceiver currently. The underlying implementation
will send different subgraphs to different SamplerReceiver in parallel
via multi-threading.
Parameters Parameters
---------- ----------
namebook : dict namebook : dict
address namebook of SamplerReceiver, where IP address namebook of SamplerReceiver, where the
key is recevier's ID and value is receiver's address, e.g., key is recevier's ID (start from 0) and value is receiver's address, e.g.,
{ 0:'168.12.23.45:50051', { 0:'168.12.23.45:50051',
1:'168.12.23.21:50051', 1:'168.12.23.21:50051',
2:'168.12.46.12:50051' } 2:'168.12.46.12:50051' }
net_type : str
networking type, e.g., 'socket' (default) or 'mpi'.
""" """
def __init__(self, namebook): def __init__(self, namebook, net_type='socket'):
assert len(namebook) > 0, 'namebook cannot be empty.' assert len(namebook) > 0, 'namebook cannot be empty.'
assert net_type in ('socket', 'mpi'), 'Unknown network type.'
self._namebook = namebook self._namebook = namebook
self._sender = _create_sender() self._sender = _create_sender(net_type)
for ID, addr in self._namebook.items(): for ID, addr in self._namebook.items():
vec = addr.split(':') ip_port = addr.split(':')
_add_receiver_addr(self._sender, vec[0], int(vec[1]), ID) assert len(ip_port) == 2, 'Uncorrect format of IP address.'
_add_receiver_addr(self._sender, ip_port[0], int(ip_port[1]), ID)
_sender_connect(self._sender) _sender_connect(self._sender)
def __del__(self): def __del__(self):
...@@ -93,36 +99,58 @@ class SamplerSender(object): ...@@ -93,36 +99,58 @@ class SamplerSender(object):
_finalize_sender(self._sender) _finalize_sender(self._sender)
def send(self, nodeflow, recv_id): def send(self, nodeflow, recv_id):
"""Send sampled subgraph (NodeFlow) to remote trainer. """Send sampled subgraph (NodeFlow) to remote trainer. Note that,
the send() API is non-blocking and it returns immediately if the
underlying message queue is not full.
Parameters Parameters
---------- ----------
nodeflow : NodeFlow nodeflow : NodeFlow
sampled NodeFlow object sampled NodeFlow
recv_id : int recv_id : int
receiver ID receiver's ID
""" """
assert recv_id >= 0, 'recv_id cannot be a negative number.'
_send_nodeflow(self._sender, nodeflow, recv_id) _send_nodeflow(self._sender, nodeflow, recv_id)
def batch_send(self, nf_list, id_list):
"""Send a batch of subgraphs (Nodeflow) to remote trainer. Note that,
the batch_send() API is non-blocking and it returns immediately if the
underlying message queue is not full.
Parameters
----------
nf_list : list
a list of NodeFlow object
id_list : list
a list of recv_id
"""
assert len(nf_list) > 0, 'nf_list cannot be empty.'
assert len(nf_list) == len(id_list), 'The length of nf_list must be equal to id_list.'
for i in range(len(nf_list)):
assert id_list[i] >= 0, 'recv_id cannot be a negative number.'
_send_nodeflow(self._sender, nf_list[i], id_list[i])
def signal(self, recv_id): def signal(self, recv_id):
"""Whene samplling of each epoch is finished, users can """When the samplling of each epoch is finished, users can
invoke this API to tell SamplerReceiver it has finished its job. invoke this API to tell SamplerReceiver that sampler has finished its job.
Parameters Parameters
---------- ----------
recv_id : int recv_id : int
receiver ID receiver's ID
""" """
_send_end_signal(self._sender, recv_id) assert recv_id >= 0, 'recv_id cannot be a negative number.'
_send_sampler_end_signal(self._sender, recv_id)
class SamplerReceiver(object): class SamplerReceiver(object):
"""SamplerReceiver for DGL distributed training. """SamplerReceiver for DGL distributed training.
Users use SamplerReceiver to receive sampled subgraph (NodeFlow) Users use SamplerReceiver to receive sampled subgraphs (NodeFlow)
from remote SamplerSender. Note that SamplerReceiver can receive messages from remote SamplerSender. Note that SamplerReceiver can receive messages
from multiple SamplerSenders concurrently by given the num_sender parameter. from multiple SamplerSenders concurrently by given the num_sender parameter.
Note that, only when all SamplerSenders connect to SamplerReceiver, receiver Only when all SamplerSenders connected to SamplerReceiver successfully,
can start its job. SamplerReceiver can start its job.
Parameters Parameters
---------- ----------
...@@ -132,15 +160,20 @@ class SamplerReceiver(object): ...@@ -132,15 +160,20 @@ class SamplerReceiver(object):
address of SamplerReceiver, e.g., '127.0.0.1:50051' address of SamplerReceiver, e.g., '127.0.0.1:50051'
num_sender : int num_sender : int
total number of SamplerSender total number of SamplerSender
net_type : str
networking type, e.g., 'socket' (default) or 'mpi'.
""" """
def __init__(self, graph, addr, num_sender): def __init__(self, graph, addr, num_sender, net_type='socket'):
assert num_sender > 0, 'num_sender must be large than zero.'
assert net_type in ('socket', 'mpi'), 'Unknown network type.'
self._graph = graph self._graph = graph
self._addr = addr self._addr = addr
self._num_sender = num_sender self._num_sender = num_sender
self._tmp_count = 0 self._tmp_count = 0
self._receiver = _create_receiver() self._receiver = _create_receiver(net_type)
vec = self._addr.split(':') ip_port = addr.split(':')
_receiver_wait(self._receiver, vec[0], int(vec[1]), self._num_sender); assert len(ip_port) == 2, 'Uncorrect format of IP address.'
_receiver_wait(self._receiver, ip_port[0], int(ip_port[1]), num_sender);
def __del__(self): def __del__(self):
"""Finalize Receiver """Finalize Receiver
...@@ -148,7 +181,7 @@ class SamplerReceiver(object): ...@@ -148,7 +181,7 @@ class SamplerReceiver(object):
_finalize_receiver(self._receiver) _finalize_receiver(self._receiver)
def __iter__(self): def __iter__(self):
"""Iterator """Sampler iterator
""" """
return self return self
...@@ -157,10 +190,10 @@ class SamplerReceiver(object): ...@@ -157,10 +190,10 @@ class SamplerReceiver(object):
""" """
while True: while True:
res = _recv_nodeflow(self._receiver, self._graph) res = _recv_nodeflow(self._receiver, self._graph)
if isinstance(res, int): if isinstance(res, int): # recv an end-signal
self._tmp_count += 1 self._tmp_count += 1
if self._tmp_count == self._num_sender: if self._tmp_count == self._num_sender:
self._tmp_count = 0 self._tmp_count = 0
raise StopIteration raise StopIteration
else: else:
return res return res # recv a nodeflow
...@@ -7,13 +7,30 @@ from . import utils ...@@ -7,13 +7,30 @@ from . import utils
_init_api("dgl.network") _init_api("dgl.network")
_CONTROL_NODEFLOW = 0
_CONTROL_END_SIGNAL = 1
def _create_sender(): ################################ Common Network Components ##################################
def _create_sender(net_type):
"""Create a Sender communicator via C api """Create a Sender communicator via C api
Parameters
----------
net_type : str
'socket' or 'mpi'
""" """
return _CAPI_DGLSenderCreate() assert net_type in ('socket', 'mpi'), 'Unknown network type.'
return _CAPI_DGLSenderCreate(net_type)
def _create_receiver(net_type):
"""Create a Receiver communicator via C api
Parameters
----------
net_type : str
'socket' or 'mpi'
"""
assert net_type in ('socket', 'mpi'), 'Unknown network type.'
return _CAPI_DGLReceiverCreate(net_type)
def _finalize_sender(sender): def _finalize_sender(sender):
"""Finalize Sender communicator """Finalize Sender communicator
...@@ -25,6 +42,11 @@ def _finalize_sender(sender): ...@@ -25,6 +42,11 @@ def _finalize_sender(sender):
""" """
_CAPI_DGLFinalizeSender(sender) _CAPI_DGLFinalizeSender(sender)
def _finalize_receiver(receiver):
"""Finalize Receiver Communicator
"""
_CAPI_DGLFinalizeReceiver(receiver)
def _add_receiver_addr(sender, ip_addr, port, recv_id): def _add_receiver_addr(sender, ip_addr, port, recv_id):
"""Add Receiver IP address to namebook """Add Receiver IP address to namebook
...@@ -39,6 +61,7 @@ def _add_receiver_addr(sender, ip_addr, port, recv_id): ...@@ -39,6 +61,7 @@ def _add_receiver_addr(sender, ip_addr, port, recv_id):
recv_id : int recv_id : int
Receiver ID Receiver ID
""" """
assert recv_id >= 0, 'recv_id cannot be a negative number.'
_CAPI_DGLSenderAddReceiver(sender, ip_addr, int(port), int(recv_id)) _CAPI_DGLSenderAddReceiver(sender, ip_addr, int(port), int(recv_id))
def _sender_connect(sender): def _sender_connect(sender):
...@@ -51,6 +74,27 @@ def _sender_connect(sender): ...@@ -51,6 +74,27 @@ def _sender_connect(sender):
""" """
_CAPI_DGLSenderConnect(sender) _CAPI_DGLSenderConnect(sender)
def _receiver_wait(receiver, ip_addr, port, num_sender):
"""Wait all Sender to connect..
Parameters
----------
receiver : ctypes.c_void_p
C Receiver handle
ip_addr : str
IP address of Receiver
port : int
port of Receiver
num_sender : int
total number of Sender
"""
assert num_sender >= 0, 'num_sender cannot be a negative number.'
_CAPI_DGLReceiverWait(receiver, ip_addr, int(port), int(num_sender))
################################ Distributed Sampler Components ################################
def _send_nodeflow(sender, nodeflow, recv_id): def _send_nodeflow(sender, nodeflow, recv_id):
"""Send sampled subgraph (Nodeflow) to remote Receiver. """Send sampled subgraph (Nodeflow) to remote Receiver.
...@@ -63,12 +107,13 @@ def _send_nodeflow(sender, nodeflow, recv_id): ...@@ -63,12 +107,13 @@ def _send_nodeflow(sender, nodeflow, recv_id):
recv_id : int recv_id : int
Receiver ID Receiver ID
""" """
assert recv_id >= 0, 'recv_id cannot be a negative number.'
gidx = nodeflow._graph gidx = nodeflow._graph
node_mapping = nodeflow._node_mapping.todgltensor() node_mapping = nodeflow._node_mapping.todgltensor()
edge_mapping = nodeflow._edge_mapping.todgltensor() edge_mapping = nodeflow._edge_mapping.todgltensor()
layers_offsets = utils.toindex(nodeflow._layer_offsets).todgltensor() layers_offsets = utils.toindex(nodeflow._layer_offsets).todgltensor()
flows_offsets = utils.toindex(nodeflow._block_offsets).todgltensor() flows_offsets = utils.toindex(nodeflow._block_offsets).todgltensor()
_CAPI_SenderSendSubgraph(sender, _CAPI_SenderSendNodeFlow(sender,
int(recv_id), int(recv_id),
gidx, gidx,
node_mapping, node_mapping,
...@@ -76,7 +121,7 @@ def _send_nodeflow(sender, nodeflow, recv_id): ...@@ -76,7 +121,7 @@ def _send_nodeflow(sender, nodeflow, recv_id):
layers_offsets, layers_offsets,
flows_offsets) flows_offsets)
def _send_end_signal(sender, recv_id): def _send_sampler_end_signal(sender, recv_id):
"""Send an epoch-end signal to remote Receiver. """Send an epoch-end signal to remote Receiver.
Parameters Parameters
...@@ -86,33 +131,8 @@ def _send_end_signal(sender, recv_id): ...@@ -86,33 +131,8 @@ def _send_end_signal(sender, recv_id):
recv_id : int recv_id : int
Receiver ID Receiver ID
""" """
_CAPI_SenderSendEndSignal(sender, int(recv_id)) assert recv_id >= 0, 'recv_id cannot be a negative number.'
_CAPI_SenderSendSamplerEndSignal(sender, int(recv_id))
def _create_receiver():
"""Create a Receiver communicator via C api
"""
return _CAPI_DGLReceiverCreate()
def _finalize_receiver(receiver):
"""Finalize Receiver Communicator
"""
_CAPI_DGLFinalizeReceiver(receiver)
def _receiver_wait(receiver, ip_addr, port, num_sender):
"""Wait all Sender to connect..
Parameters
----------
receiver : ctypes.c_void_p
C Receiver handle
ip_addr : str
IP address of Receiver
port : int
port of Receiver
num_sender : int
total number of Sender
"""
_CAPI_DGLReceiverWait(receiver, ip_addr, int(port), int(num_sender))
def _recv_nodeflow(receiver, graph): def _recv_nodeflow(receiver, graph):
"""Receive sampled subgraph (NodeFlow) from remote sampler. """Receive sampled subgraph (NodeFlow) from remote sampler.
...@@ -126,15 +146,10 @@ def _recv_nodeflow(receiver, graph): ...@@ -126,15 +146,10 @@ def _recv_nodeflow(receiver, graph):
Returns Returns
------- -------
NodeFlow NodeFlow or an end-signal
Sampled NodeFlow object
""" """
res = _CAPI_ReceiverRecvSubgraph(receiver) res = _CAPI_ReceiverRecvNodeFlow(receiver)
if isinstance(res, int): if isinstance(res, int):
if res == _CONTROL_END_SIGNAL: return res
return _CONTROL_END_SIGNAL
else:
raise RuntimeError('Got unexpected control code {}'.format(res))
else: else:
# res is of type List<NodeFlowObject> return NodeFlow(graph, res)
return NodeFlow(graph, res[0])
...@@ -3,55 +3,120 @@ ...@@ -3,55 +3,120 @@
* \file graph/network.cc * \file graph/network.cc
* \brief DGL networking related APIs * \brief DGL networking related APIs
*/ */
#include "./network.h"
#include <dgl/runtime/container.h> #include <dgl/runtime/container.h>
#include <dgl/runtime/ndarray.h>
#include <dgl/packed_func_ext.h> #include <dgl/packed_func_ext.h>
#include "./network.h" #include <dgl/immutable_graph.h>
#include <dgl/nodeflow.h>
#include <unordered_map>
#include "./network/communicator.h" #include "./network/communicator.h"
#include "./network/socket_communicator.h" #include "./network/socket_communicator.h"
#include "./network/serialize.h" #include "./network/msg_queue.h"
#include "./network/common.h"
#include "../c_api_common.h"
using dgl::network::StringPrintf;
using namespace dgl::runtime; using namespace dgl::runtime;
namespace dgl { namespace dgl {
namespace network { namespace network {
// Wrapper for Send api void MsgMeta::AddArray(const NDArray& array) {
static void SendData(network::Sender* sender, // We first write the ndim to the data_shape_
const char* data, data_shape_.push_back(static_cast<int64_t>(array->ndim));
int64_t size, // Then we write the data shape
int recv_id) { for (int i = 0; i < array->ndim; ++i) {
int64_t send_size = sender->Send(data, size, recv_id); data_shape_.push_back(array->shape[i]);
if (send_size <= 0) { }
LOG(FATAL) << "Send error (size: " << send_size << ")"; ndarray_count_++;
}
char* MsgMeta::Serialize(int64_t* size) {
char* buffer = nullptr;
int64_t buffer_size = 0;
buffer_size += sizeof(msg_type_);
if (ndarray_count_ != 0) {
buffer_size += sizeof(ndarray_count_);
buffer_size += sizeof(data_shape_.size());
buffer_size += sizeof(int64_t) * data_shape_.size();
}
buffer = new char[buffer_size];
char* pointer = buffer;
// Write msg_type_
*(reinterpret_cast<int*>(pointer)) = msg_type_;
pointer += sizeof(msg_type_);
if (ndarray_count_ != 0) {
// Write ndarray_count_
*(reinterpret_cast<int*>(pointer)) = ndarray_count_;
pointer += sizeof(ndarray_count_);
// Write size of data_shape_
*(reinterpret_cast<size_t*>(pointer)) = data_shape_.size();
pointer += sizeof(data_shape_.size());
// Write data of data_shape_
memcpy(pointer,
reinterpret_cast<char*>(data_shape_.data()),
sizeof(int64_t) * data_shape_.size());
} }
*size = buffer_size;
return buffer;
} }
// Wrapper for Recv api void MsgMeta::Deserialize(char* buffer, int64_t size) {
static void RecvData(network::Receiver* receiver, int64_t data_size = 0;
char* dest, // Read mesg_type_
int64_t max_size) { msg_type_ = *(reinterpret_cast<int*>(buffer));
int64_t recv_size = receiver->Recv(dest, max_size); buffer += sizeof(int);
if (recv_size <= 0) { data_size += sizeof(int);
LOG(FATAL) << "Receive error (size: " << recv_size << ")"; if (data_size < size) {
// Read ndarray_count_
ndarray_count_ = *(reinterpret_cast<int*>(buffer));
buffer += sizeof(int);
data_size += sizeof(int);
// Read size of data_shape_
size_t count = *(reinterpret_cast<size_t*>(buffer));
buffer += sizeof(size_t);
data_size += sizeof(size_t);
data_shape_.resize(count);
// Read data of data_shape_
memcpy(data_shape_.data(), buffer,
count * sizeof(int64_t));
data_size += count * sizeof(int64_t);
} }
CHECK_EQ(data_size, size);
} }
////////////////////////////////// Basic Networking Components ////////////////////////////////
DGL_REGISTER_GLOBAL("network._CAPI_DGLSenderCreate") DGL_REGISTER_GLOBAL("network._CAPI_DGLSenderCreate")
.set_body([] (DGLArgs args, DGLRetValue* rv) { .set_body([] (DGLArgs args, DGLRetValue* rv) {
network::Sender* sender = new network::SocketSender(); std::string type = args[0];
try { network::Sender* sender = nullptr;
char* buffer = new char[kMaxBufferSize]; if (type == "socket") {
sender->SetBuffer(buffer); sender = new network::SocketSender(kQueueSize);
} catch (const std::bad_alloc&) { } else {
LOG(FATAL) << "Not enough memory for sender buffer: " << kMaxBufferSize; LOG(FATAL) << "Unknown communicator type: " << type;
} }
CommunicatorHandle chandle = static_cast<CommunicatorHandle>(sender); CommunicatorHandle chandle = static_cast<CommunicatorHandle>(sender);
*rv = chandle; *rv = chandle;
}); });
DGL_REGISTER_GLOBAL("network._CAPI_DGLReceiverCreate")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
std::string type = args[0];
network::Receiver* receiver = nullptr;
if (type == "socket") {
receiver = new network::SocketReceiver(kQueueSize);
} else {
LOG(FATAL) << "Unknown communicator type: " << type;
}
CommunicatorHandle chandle = static_cast<CommunicatorHandle>(receiver);
*rv = chandle;
});
DGL_REGISTER_GLOBAL("network._CAPI_DGLFinalizeSender") DGL_REGISTER_GLOBAL("network._CAPI_DGLFinalizeSender")
.set_body([] (DGLArgs args, DGLRetValue* rv) { .set_body([] (DGLArgs args, DGLRetValue* rv) {
CommunicatorHandle chandle = args[0]; CommunicatorHandle chandle = args[0];
...@@ -59,6 +124,13 @@ DGL_REGISTER_GLOBAL("network._CAPI_DGLFinalizeSender") ...@@ -59,6 +124,13 @@ DGL_REGISTER_GLOBAL("network._CAPI_DGLFinalizeSender")
sender->Finalize(); sender->Finalize();
}); });
DGL_REGISTER_GLOBAL("network._CAPI_DGLFinalizeReceiver")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
CommunicatorHandle chandle = args[0];
network::Receiver* receiver = static_cast<network::SocketReceiver*>(chandle);
receiver->Finalize();
});
DGL_REGISTER_GLOBAL("network._CAPI_DGLSenderAddReceiver") DGL_REGISTER_GLOBAL("network._CAPI_DGLSenderAddReceiver")
.set_body([] (DGLArgs args, DGLRetValue* rv) { .set_body([] (DGLArgs args, DGLRetValue* rv) {
CommunicatorHandle chandle = args[0]; CommunicatorHandle chandle = args[0];
...@@ -66,7 +138,13 @@ DGL_REGISTER_GLOBAL("network._CAPI_DGLSenderAddReceiver") ...@@ -66,7 +138,13 @@ DGL_REGISTER_GLOBAL("network._CAPI_DGLSenderAddReceiver")
int port = args[2]; int port = args[2];
int recv_id = args[3]; int recv_id = args[3];
network::Sender* sender = static_cast<network::Sender*>(chandle); network::Sender* sender = static_cast<network::Sender*>(chandle);
sender->AddReceiver(ip.c_str(), port, recv_id); std::string addr;
if (sender->Type() == "socket") {
addr = StringPrintf("socket://%s:%d", ip.c_str(), port);
} else {
LOG(FATAL) << "Unknown communicator type: " << sender->Type();
}
sender->AddReceiver(addr.c_str(), recv_id);
}); });
DGL_REGISTER_GLOBAL("network._CAPI_DGLSenderConnect") DGL_REGISTER_GLOBAL("network._CAPI_DGLSenderConnect")
...@@ -78,104 +156,218 @@ DGL_REGISTER_GLOBAL("network._CAPI_DGLSenderConnect") ...@@ -78,104 +156,218 @@ DGL_REGISTER_GLOBAL("network._CAPI_DGLSenderConnect")
} }
}); });
DGL_REGISTER_GLOBAL("network._CAPI_SenderSendSubgraph") DGL_REGISTER_GLOBAL("network._CAPI_DGLReceiverWait")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
CommunicatorHandle chandle = args[0];
std::string ip = args[1];
int port = args[2];
int num_sender = args[3];
network::Receiver* receiver = static_cast<network::SocketReceiver*>(chandle);
std::string addr;
if (receiver->Type() == "socket") {
addr = StringPrintf("socket://%s:%d", ip.c_str(), port);
} else {
LOG(FATAL) << "Unknown communicator type: " << receiver->Type();
}
if (receiver->Wait(addr.c_str(), num_sender) == false) {
LOG(FATAL) << "Wait sender socket failed.";
}
});
////////////////////////// Distributed Sampler Components ////////////////////////////////
DGL_REGISTER_GLOBAL("network._CAPI_SenderSendNodeFlow")
.set_body([] (DGLArgs args, DGLRetValue* rv) { .set_body([] (DGLArgs args, DGLRetValue* rv) {
CommunicatorHandle chandle = args[0]; CommunicatorHandle chandle = args[0];
int recv_id = args[1]; int recv_id = args[1];
// TODO(minjie): could simply use NodeFlow nf = args[2];
GraphRef g = args[2]; GraphRef g = args[2];
const IdArray node_mapping = args[3]; NDArray node_mapping = args[3];
const IdArray edge_mapping = args[4]; NDArray edge_mapping = args[4];
const IdArray layer_offsets = args[5]; NDArray layer_offsets = args[5];
const IdArray flow_offsets = args[6]; NDArray flow_offsets = args[6];
auto ptr = std::dynamic_pointer_cast<ImmutableGraph>(g.sptr()); auto ptr = std::dynamic_pointer_cast<ImmutableGraph>(g.sptr());
CHECK(ptr) << "only immutable graph is allowed in send/recv"; CHECK(ptr) << "only immutable graph is allowed in send/recv";
network::Sender* sender = static_cast<network::Sender*>(chandle);
auto csr = ptr->GetInCSR(); auto csr = ptr->GetInCSR();
// Write control message // Create a message for the meta data of ndarray
char* buffer = sender->GetBuffer(); NDArray indptr = csr->indptr();
*buffer = CONTROL_NODEFLOW; NDArray indice = csr->indices();
// Serialize nodeflow to data buffer NDArray edge_ids = csr->edge_ids();
int64_t data_size = network::SerializeSampledSubgraph( MsgMeta msg(kNodeFlowMsg);
buffer+sizeof(CONTROL_NODEFLOW), msg.AddArray(node_mapping);
csr, msg.AddArray(edge_mapping);
node_mapping, msg.AddArray(layer_offsets);
edge_mapping, msg.AddArray(flow_offsets);
layer_offsets, msg.AddArray(indptr);
flow_offsets); msg.AddArray(indice);
CHECK_GT(data_size, 0); msg.AddArray(edge_ids);
data_size += sizeof(CONTROL_NODEFLOW); // send meta message
// Send msg via network int64_t size = 0;
SendData(sender, buffer, data_size, recv_id); char* data = msg.Serialize(&size);
network::Sender* sender = static_cast<network::Sender*>(chandle);
Message send_msg;
send_msg.data = data;
send_msg.size = size;
send_msg.deallocator = DefaultMessageDeleter;
CHECK_NE(sender->Send(send_msg, recv_id), -1);
// send node_mapping
Message node_mapping_msg;
node_mapping_msg.data = static_cast<char*>(node_mapping->data);
node_mapping_msg.size = node_mapping.GetSize();
node_mapping_msg.aux_handler = &node_mapping;
node_mapping_msg.deallocator = NDArrayDeleter;
CHECK_NE(sender->Send(node_mapping_msg, recv_id), -1);
// send edege_mapping
Message edge_mapping_msg;
edge_mapping_msg.data = static_cast<char*>(edge_mapping->data);
edge_mapping_msg.size = edge_mapping.GetSize();
edge_mapping_msg.aux_handler = &edge_mapping;
edge_mapping_msg.deallocator = NDArrayDeleter;
CHECK_NE(sender->Send(edge_mapping_msg, recv_id), -1);
// send layer_offsets
Message layer_offsets_msg;
layer_offsets_msg.data = static_cast<char*>(layer_offsets->data);
layer_offsets_msg.size = layer_offsets.GetSize();
layer_offsets_msg.aux_handler = &layer_offsets;
layer_offsets_msg.deallocator = NDArrayDeleter;
CHECK_NE(sender->Send(layer_offsets_msg, recv_id), -1);
// send flow_offset
Message flow_offsets_msg;
flow_offsets_msg.data = static_cast<char*>(flow_offsets->data);
flow_offsets_msg.size = flow_offsets.GetSize();
flow_offsets_msg.aux_handler = &flow_offsets;
flow_offsets_msg.deallocator = NDArrayDeleter;
CHECK_NE(sender->Send(flow_offsets_msg, recv_id), -1);
// send csr->indptr
Message indptr_msg;
indptr_msg.data = static_cast<char*>(indptr->data);
indptr_msg.size = indptr.GetSize();
indptr_msg.aux_handler = &indptr;
indptr_msg.deallocator = NDArrayDeleter;
CHECK_NE(sender->Send(indptr_msg, recv_id), -1);
// send csr->indices
Message indices_msg;
indices_msg.data = static_cast<char*>(indice->data);
indices_msg.size = indice.GetSize();
indices_msg.aux_handler = &indice;
indices_msg.deallocator = NDArrayDeleter;
CHECK_NE(sender->Send(indices_msg, recv_id), -1);
// send csr->edge_ids
Message edge_ids_msg;
edge_ids_msg.data = static_cast<char*>(csr->edge_ids()->data);
edge_ids_msg.size = csr->edge_ids().GetSize();
edge_ids_msg.aux_handler = &edge_ids;
edge_ids_msg.deallocator = NDArrayDeleter;
CHECK_NE(sender->Send(edge_ids_msg, recv_id), -1);
}); });
DGL_REGISTER_GLOBAL("network._CAPI_SenderSendEndSignal") DGL_REGISTER_GLOBAL("network._CAPI_SenderSendSamplerEndSignal")
.set_body([] (DGLArgs args, DGLRetValue* rv) { .set_body([] (DGLArgs args, DGLRetValue* rv) {
CommunicatorHandle chandle = args[0]; CommunicatorHandle chandle = args[0];
int recv_id = args[1]; int recv_id = args[1];
MsgMeta msg(kEndMsg);
int64_t size = 0;
char* data = msg.Serialize(&size);
network::Sender* sender = static_cast<network::Sender*>(chandle); network::Sender* sender = static_cast<network::Sender*>(chandle);
char* buffer = sender->GetBuffer(); Message send_msg = {data, size};
*buffer = CONTROL_END_SIGNAL; send_msg.deallocator = DefaultMessageDeleter;
// Send msg via network CHECK_NE(sender->Send(send_msg, recv_id), -1);
SendData(sender, buffer, sizeof(CONTROL_END_SIGNAL), recv_id);
});
DGL_REGISTER_GLOBAL("network._CAPI_DGLReceiverCreate")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
network::Receiver* receiver = new network::SocketReceiver();
try {
char* buffer = new char[kMaxBufferSize];
receiver->SetBuffer(buffer);
} catch (const std::bad_alloc&) {
LOG(FATAL) << "Not enough memory for receiver buffer: " << kMaxBufferSize;
}
CommunicatorHandle chandle = static_cast<CommunicatorHandle>(receiver);
*rv = chandle;
}); });
DGL_REGISTER_GLOBAL("network._CAPI_DGLFinalizeReceiver") static void ConstructNFTensor(DLTensor *tensor, char* data, int64_t shape_0) {
.set_body([] (DGLArgs args, DGLRetValue* rv) { tensor->data = data;
CommunicatorHandle chandle = args[0]; tensor->ctx = DLContext{kDLCPU, 0};
network::Receiver* receiver = static_cast<network::SocketReceiver*>(chandle); tensor->ndim = 1;
receiver->Finalize(); tensor->dtype = DLDataType{kDLInt, 64, 1};
}); tensor->shape = new int64_t[1];
tensor->shape[0] = shape_0;
DGL_REGISTER_GLOBAL("network._CAPI_DGLReceiverWait") tensor->byte_offset = 0;
.set_body([] (DGLArgs args, DGLRetValue* rv) { }
CommunicatorHandle chandle = args[0];
std::string ip = args[1];
int port = args[2];
int num_sender = args[3];
network::Receiver* receiver = static_cast<network::SocketReceiver*>(chandle);
receiver->Wait(ip.c_str(), port, num_sender, kQueueSize);
});
DGL_REGISTER_GLOBAL("network._CAPI_ReceiverRecvSubgraph") DGL_REGISTER_GLOBAL("network._CAPI_ReceiverRecvNodeFlow")
.set_body([] (DGLArgs args, DGLRetValue* rv) { .set_body([] (DGLArgs args, DGLRetValue* rv) {
CommunicatorHandle chandle = args[0]; CommunicatorHandle chandle = args[0];
network::Receiver* receiver = static_cast<network::SocketReceiver*>(chandle); network::Receiver* receiver = static_cast<network::SocketReceiver*>(chandle);
// Recv data from network int send_id = 0;
char* buffer = receiver->GetBuffer(); Message recv_msg;
RecvData(receiver, buffer, kMaxBufferSize); receiver->Recv(&recv_msg, &send_id);
int control = *buffer; MsgMeta msg(recv_msg.data, recv_msg.size);
if (control == CONTROL_NODEFLOW) { recv_msg.deallocator(&recv_msg);
if (msg.msg_type() == kNodeFlowMsg) {
CHECK_EQ(msg.ndarray_count() * 2, msg.data_shape_.size());
NodeFlow nf = NodeFlow::Create(); NodeFlow nf = NodeFlow::Create();
CSRPtr csr; // node_mapping
// Deserialize nodeflow from recv_data_buffer Message array_0;
network::DeserializeSampledSubgraph(buffer+sizeof(CONTROL_NODEFLOW), CHECK_NE(receiver->RecvFrom(&array_0, send_id), -1);
&(csr), CHECK_EQ(msg.data_shape_[0], 1);
&(nf->node_mapping), DLTensor node_mapping_tensor;
&(nf->edge_mapping), ConstructNFTensor(&node_mapping_tensor, array_0.data, msg.data_shape_[1]);
&(nf->layer_offsets), DLManagedTensor *node_mapping_managed_tensor = new DLManagedTensor();
&(nf->flow_offsets)); node_mapping_managed_tensor->dl_tensor = node_mapping_tensor;
nf->node_mapping = NDArray::FromDLPack(node_mapping_managed_tensor);
// edge_mapping
Message array_1;
CHECK_NE(receiver->RecvFrom(&array_1, send_id), -1);
CHECK_EQ(msg.data_shape_[2], 1);
DLTensor edge_mapping_tensor;
ConstructNFTensor(&edge_mapping_tensor, array_1.data, msg.data_shape_[3]);
DLManagedTensor *edge_mapping_managed_tensor = new DLManagedTensor();
edge_mapping_managed_tensor->dl_tensor = edge_mapping_tensor;
nf->edge_mapping = NDArray::FromDLPack(edge_mapping_managed_tensor);
// layer_offset
Message array_2;
CHECK_NE(receiver->RecvFrom(&array_2, send_id), -1);
CHECK_EQ(msg.data_shape_[4], 1);
DLTensor layer_offsets_tensor;
ConstructNFTensor(&layer_offsets_tensor, array_2.data, msg.data_shape_[5]);
DLManagedTensor *layer_offsets_managed_tensor = new DLManagedTensor();
layer_offsets_managed_tensor->dl_tensor = layer_offsets_tensor;
nf->layer_offsets = NDArray::FromDLPack(layer_offsets_managed_tensor);
// flow_offset
Message array_3;
CHECK_NE(receiver->RecvFrom(&array_3, send_id), -1);
CHECK_EQ(msg.data_shape_[6], 1);
DLTensor flow_offsets_tensor;
ConstructNFTensor(&flow_offsets_tensor, array_3.data, msg.data_shape_[7]);
DLManagedTensor *flow_offsets_managed_tensor = new DLManagedTensor();
flow_offsets_managed_tensor->dl_tensor = flow_offsets_tensor;
nf->flow_offsets = NDArray::FromDLPack(flow_offsets_managed_tensor);
// CSR indptr
Message array_4;
CHECK_NE(receiver->RecvFrom(&array_4, send_id), -1);
CHECK_EQ(msg.data_shape_[8], 1);
DLTensor indptr_tensor;
ConstructNFTensor(&indptr_tensor, array_4.data, msg.data_shape_[9]);
DLManagedTensor *indptr_managed_tensor = new DLManagedTensor();
indptr_managed_tensor->dl_tensor = indptr_tensor;
NDArray indptr = NDArray::FromDLPack(indptr_managed_tensor);
// CSR indice
Message array_5;
CHECK_NE(receiver->RecvFrom(&array_5, send_id), -1);
CHECK_EQ(msg.data_shape_[10], 1);
DLTensor indice_tensor;
ConstructNFTensor(&indice_tensor, array_5.data, msg.data_shape_[11]);
DLManagedTensor *indice_managed_tensor = new DLManagedTensor();
indice_managed_tensor->dl_tensor = indice_tensor;
NDArray indice = NDArray::FromDLPack(indice_managed_tensor);
// CSR edge_ids
Message array_6;
CHECK_NE(receiver->RecvFrom(&array_6, send_id), -1);
CHECK_EQ(msg.data_shape_[12], 1);
DLTensor edge_id_tensor;
ConstructNFTensor(&edge_id_tensor, array_6.data, msg.data_shape_[13]);
DLManagedTensor *edge_id_managed_tensor = new DLManagedTensor();
edge_id_managed_tensor->dl_tensor = edge_id_tensor;
NDArray edge_ids = NDArray::FromDLPack(edge_id_managed_tensor);
// Create CSR
CSRPtr csr(new CSR(indptr, indice, edge_ids));
nf->graph = GraphPtr(new ImmutableGraph(csr, nullptr)); nf->graph = GraphPtr(new ImmutableGraph(csr, nullptr));
List<NodeFlow> subgs; *rv = nf;
subgs.push_back(nf); } else if (msg.msg_type() == kEndMsg) {
*rv = subgs; *rv = msg.msg_type();
} else if (control == CONTROL_END_SIGNAL) {
*rv = CONTROL_END_SIGNAL;
} else { } else {
LOG(FATAL) << "Unknow control number: " << control; LOG(FATAL) << "Unknown message type: " << msg.msg_type();
} }
}); });
......
...@@ -7,25 +7,117 @@ ...@@ -7,25 +7,117 @@
#define DGL_GRAPH_NETWORK_H_ #define DGL_GRAPH_NETWORK_H_
#include <dmlc/logging.h> #include <dmlc/logging.h>
#include <dgl/runtime/ndarray.h>
#include <string.h>
#include <vector>
#include "../c_api_common.h"
#include "./network/msg_queue.h"
using dgl::runtime::NDArray;
namespace dgl { namespace dgl {
namespace network { namespace network {
#define IS_SENDER true // Max size of message queue for communicator is 200 MB
#define IS_RECEIVER false // TODO(chao): Make this number configurable
const int64_t kQueueSize = 200 * 1024 * 1024;
/*!
* \brief Free memory buffer of NodeFlow
*/
inline void NDArrayDeleter(Message* msg) {
delete reinterpret_cast<NDArray*>(msg->aux_handler);
}
/*!
* \brief Message type for DGL distributed training
*/
enum MessageType {
/*!
* \brief Message for send/recv NodeFlow
*/
kNodeFlowMsg = 0,
/*!
* \brief Message for end-signal
*/
kEndMsg = 1
};
/*!
* \brief Meta data for communicator message
*/
class MsgMeta {
public:
/*!
* \brief MsgMeta constructor.
* \param msg_type type of message
*/
explicit MsgMeta(int msg_type)
: msg_type_(msg_type), ndarray_count_(0) {}
/*!
* \brief Construct MsgMeta from binary data buffer.
* \param buffer data buffer
* \param size data size
*/
MsgMeta(char* buffer, int64_t size) {
CHECK_NOTNULL(buffer);
this->Deserialize(buffer, size);
}
/*!
* \return message type
*/
inline int msg_type() const {
return msg_type_;
}
/*!
* \return count of ndarray
*/
inline int ndarray_count() const {
return ndarray_count_;
}
/*!
* \brief Add NDArray meta data to MsgMeta
* \param array DGL NDArray
*/
void AddArray(const NDArray& array);
/*!
* \brief Serialize MsgMeta to data buffer
* \param size size of serialized message
* \return pointer of data buffer
*/
char* Serialize(int64_t* size);
/*!
* \brief Deserialize MsgMeta from data buffer
* \param buffer data buffer
* \param size size of data buffer
*/
void Deserialize(char* buffer, int64_t size);
/*!
* \brief type of message
*/
int msg_type_;
// TODO(chao): make these numbers configurable /*!
* \brief count of ndarray in MetaMsg
*/
int ndarray_count_;
// Each single message cannot larger than 300 MB /*!
const int64_t kMaxBufferSize = 300 * 1024 * 2014; * \brief We first write the ndim to data_shape_
// Size of message queue is 1 GB * and then write the data shape.
const int64_t kQueueSize = 1024 * 1024 * 1024; */
// Maximal try count of connection std::vector<int64_t> data_shape_;
const int kMaxTryCount = 500; };
// Control number
const int CONTROL_NODEFLOW = 0;
const int CONTROL_END_SIGNAL = 1;
} // namespace network } // namespace network
} // namespace dgl } // namespace dgl
......
/*!
* Copyright (c) 2019 by Contributors
* \file common.cc
* \brief This file provide basic facilities for string
* to make programming convenient.
*/
#include "common.h"

#include <stdarg.h>
#include <stdio.h>

#include <vector>
using std::string;
namespace dgl {
namespace network {
// In most cases, delim contains only one character. In this case, we
// use CalculateReserveForVector to count the number of elements that
// should be reserved in the result vector, and thus optimize SplitStringUsing.
// Count how many tokens splitting `full` on `delim` will produce, so the
// caller can reserve() the output vector up front.  Only the common case of
// a single-character delimiter is counted; for an empty or multi-character
// delimiter this returns 0 (i.e., no reservation).
static int CalculateReserveForVector(const std::string& full, const char* delim) {
  int num_tokens = 0;
  // Bail out unless delim is exactly one character long.
  if (delim[0] == '\0' || delim[1] != '\0') {
    return num_tokens;
  }
  const char sep = delim[0];
  const char* cur = full.data();
  const char* const last = cur + full.size();
  while (cur != last) {
    if (*cur == sep) {  // This could be optimized with hasless(v,1) trick.
      ++cur;
      continue;
    }
    // Found the start of a token; scan to its end.
    ++num_tokens;
    do {
      ++cur;
    } while (cur != last && *cur != sep);
  }
  return num_tokens;
}
void SplitStringUsing(const std::string& full,
const char* delim,
std::vector<std::string>* result) {
CHECK(delim != NULL);
CHECK(result != NULL);
result->reserve(CalculateReserveForVector(full, delim));
back_insert_iterator< std::vector<std::string> > it(*result);
SplitStringToIteratorUsing(full, delim, &it);
}
void SplitStringToSetUsing(const std::string& full,
const char* delim,
std::set<std::string>* result) {
CHECK(delim != NULL);
CHECK(result != NULL);
simple_insert_iterator<std::set<std::string> > it(result);
SplitStringToIteratorUsing(full, delim, &it);
}
// Format `format` with the arguments in `ap` and append the result to *dst.
// A small stack buffer is tried first; if the output does not fit, a heap
// buffer is grown until it does.  vsnprintf may invalidate a va_list when it
// consumes it, so every attempt works on a va_copy of `ap`.
static void StringAppendV(std::string* dst, const char* format, va_list ap) {
  // First try with a small fixed-size buffer.
  char space[1024];

  // It's possible for methods that use a va_list to invalidate
  // the data in it upon use.  The fix is to make a copy
  // of the structure before using it and use that copy instead.
  va_list backup_ap;
  va_copy(backup_ap, ap);
  int result = vsnprintf(space, sizeof(space), format, backup_ap);
  va_end(backup_ap);

  // Cast after the >= 0 check to avoid a signed/unsigned comparison.
  if (result >= 0 && static_cast<size_t>(result) < sizeof(space)) {
    // It fit.
    dst->append(space, static_cast<size_t>(result));
    return;
  }

  // Repeatedly increase the buffer size until the output fits.
  int length = static_cast<int>(sizeof(space));
  while (true) {
    if (result < 0) {
      // Older vsnprintf behavior: no required size reported, just try
      // doubling the buffer size.
      length *= 2;
    } else {
      // vsnprintf reported the required size: exactly "result + 1" chars.
      length = result + 1;
    }
    // RAII buffer: no leak even if append() below were to throw.
    std::vector<char> buf(static_cast<size_t>(length));
    // Restore the va_list before we use it again.
    va_copy(backup_ap, ap);
    result = vsnprintf(buf.data(), buf.size(), format, backup_ap);
    va_end(backup_ap);
    if (result >= 0 && result < length) {
      // It fit.
      dst->append(buf.data(), static_cast<size_t>(result));
      return;
    }
  }
}
// Return the printf-style formatting of (format, ...) as a std::string.
string StringPrintf(const char* format, ...) {
  string formatted;
  va_list args;
  va_start(args, format);
  StringAppendV(&formatted, format, args);
  va_end(args);
  return formatted;
}
// Overwrite *dst with the printf-style formatting of (format, ...).
void SStringPrintf(string* dst, const char* format, ...) {
  // Clear first: this function replaces the contents rather than appending.
  dst->clear();
  va_list args;
  va_start(args, format);
  StringAppendV(dst, format, args);
  va_end(args);
}
// Append the printf-style formatting of (format, ...) to *dst.
void StringAppendF(string* dst, const char* format, ...) {
  va_list args;
  va_start(args, format);
  StringAppendV(dst, format, args);
  va_end(args);
}
} // namespace network
} // namespace dgl
/*!
* Copyright (c) 2019 by Contributors
* \file common.h
* \brief This file provide basic facilities for string
* to make programming convenient.
*/
#ifndef DGL_GRAPH_NETWORK_COMMON_H_
#define DGL_GRAPH_NETWORK_COMMON_H_
#include <dmlc/logging.h>
#include <set>
#include <string>
#include <vector>
namespace dgl {
namespace network {
//------------------------------------------------------------------------------
// Subdivide string |full| into substrings according to delimiters
// given in |delim|. |delim| should point to a string including
// one or more characters. Each character is considered a possible
// delimiter. For example:
//
//   vector<string> substrings;
//   SplitStringUsing("apple orange\tbanana", "\t ", &substrings);
//
// results in three substrings:
//
//   substrings.size() == 3
//   substrings[0] == "apple"
//   substrings[1] == "orange"
//   substrings[2] == "banana"
//------------------------------------------------------------------------------
void SplitStringUsing(const std::string& full,
const char* delim,
std::vector<std::string>* result);
// This function has the same semantics as SplitStringUsing. Results
// are saved in an STL set container.
void SplitStringToSetUsing(const std::string& full,
const char* delim,
std::set<std::string>* result);
// Minimal output-iterator-style adapter that inserts assigned values into an
// STL set-like container via insert().  Only the operations needed by
// SplitStringToIteratorUsing's  *(*it)++ = value  idiom are provided.
template <typename T>
struct simple_insert_iterator {
  // Takes a non-owning pointer to the destination container.
  explicit simple_insert_iterator(T* t) : t_(t) { }
  // Assignment performs the actual insertion into the container.
  simple_insert_iterator<T>& operator=(const typename T::value_type& value) {
    t_->insert(value);
    return *this;
  }
  // Dereference and increment are intentional no-ops; the adapter itself
  // stands in for the "position" an output iterator would track.
  simple_insert_iterator<T>& operator*() { return *this; }
  simple_insert_iterator<T>& operator++() { return *this; }
  simple_insert_iterator<T>& operator++(int placeholder) { return *this; }
  T* t_;  // destination container (not owned)
};
// Minimal output-iterator-style adapter that appends assigned values to a
// sequence container (e.g. std::vector) via push_back().  Analogous to
// std::back_insert_iterator but restricted to the operations used by
// SplitStringToIteratorUsing.
template <typename T>
struct back_insert_iterator {
  // Takes a reference to the destination container; the container must
  // outlive this iterator.
  explicit back_insert_iterator(T& t) : t_(t) {}
  // Assignment performs the actual append.
  back_insert_iterator<T>& operator=(const typename T::value_type& value) {
    t_.push_back(value);
    return *this;
  }
  // Dereference and increment are intentional no-ops (see above).
  back_insert_iterator<T>& operator*() { return *this; }
  back_insert_iterator<T>& operator++() { return *this; }
  back_insert_iterator<T> operator++(int placeholder) { return *this; }
  T& t_;  // destination container (reference member: not assignable)
};
// Tokenize `full` on any of the characters in `delim`, writing each token
// through the output-iterator-like object *result.  Consecutive delimiters
// produce no empty tokens.
template <typename StringType, typename ITR>
static inline
void SplitStringToIteratorUsing(const StringType& full,
                                const char* delim,
                                ITR* result) {
  CHECK_NOTNULL(delim);
  // Fast path: a single-character delimiter lets us scan raw memory.
  if (delim[0] != '\0' && delim[1] == '\0') {
    const char sep = delim[0];
    const char* cur = full.data();
    const char* const last = cur + full.size();
    while (cur != last) {
      if (*cur == sep) {
        ++cur;
        continue;
      }
      const char* token_begin = cur;
      while (++cur != last && *cur != sep) {
        // Advance to the next occurrence of the delimiter.
      }
      *(*result)++ = StringType(token_begin, cur - token_begin);
    }
    return;
  }
  // General path: any of several delimiter characters may end a token.
  std::string::size_type token_start = full.find_first_not_of(delim);
  while (token_start != std::string::npos) {
    const std::string::size_type token_end =
        full.find_first_of(delim, token_start);
    if (token_end == std::string::npos) {
      // Last token runs to the end of the input.
      *(*result)++ = full.substr(token_start);
      return;
    }
    *(*result)++ = full.substr(token_start, token_end - token_start);
    token_start = full.find_first_not_of(delim, token_end);
  }
}
//------------------------------------------------------------------------------
// StringPrintf:
//
// For example:
//
// std::string str = StringPrintf("%d", 1); /* str = "1" */
// SStringPrintf(&str, "%d", 2); /* str = "2" */
// StringAppendF(&str, "%d", 3); /* str = "23" */
//------------------------------------------------------------------------------
std::string StringPrintf(const char* format, ...);
void SStringPrintf(std::string* dst, const char* format, ...);
void StringAppendF(std::string* dst, const char* format, ...);
} // namespace network
} // namespace dgl
#endif // DGL_GRAPH_NETWORK_COMMON_H_
...@@ -6,112 +6,165 @@ ...@@ -6,112 +6,165 @@
#ifndef DGL_GRAPH_NETWORK_COMMUNICATOR_H_ #ifndef DGL_GRAPH_NETWORK_COMMUNICATOR_H_
#define DGL_GRAPH_NETWORK_COMMUNICATOR_H_ #define DGL_GRAPH_NETWORK_COMMUNICATOR_H_
#include <dmlc/logging.h>
#include <string> #include <string>
#include "msg_queue.h"
namespace dgl { namespace dgl {
namespace network { namespace network {
/*! /*!
* \brief Network Sender for DGL distributed training. * \brief Network Sender for DGL distributed training.
* *
* Sender is an abstract class that defines a set of APIs for sending * Sender is an abstract class that defines a set of APIs for sending binary
* binary data over network. It can be implemented by different underlying * data message over network. It can be implemented by different underlying
* networking libraries such TCP socket and ZMQ. One Sender can connect to * networking libraries such TCP socket and MPI. One Sender can connect to
* multiple receivers, and it can send data to specified receiver via receiver's ID. * multiple receivers and it can send data to specified receiver via receiver's ID.
*/ */
class Sender { class Sender {
public: public:
/*!
* \brief Sender constructor
* \param queue_size size (bytes) of message queue.
* Note that, the queue_size parameter is optional.
*/
explicit Sender(int64_t queue_size = 0) {
CHECK_GE(queue_size, 0);
queue_size_ = queue_size;
}
virtual ~Sender() {} virtual ~Sender() {}
/*! /*!
* \brief Add receiver address and it's ID to the namebook * \brief Add receiver's address and ID to the sender's namebook
* \param ip receviver's IP address * \param addr Networking address, e.g., 'socket://127.0.0.1:50091', 'mpi://0'
* \param port receiver's port
* \param id receiver's ID * \param id receiver's ID
*
* AddReceiver() is not thread-safe and only one thread can invoke this API.
*/ */
virtual void AddReceiver(const char* ip, int port, int id) = 0; virtual void AddReceiver(const char* addr, int id) = 0;
/*! /*!
* \brief Connect with all the Receivers * \brief Connect with all the Receivers
* \return True for sucess and False for fail * \return True for success and False for fail
*
* Connect() is not thread-safe and only one thread can invoke this API.
*/ */
virtual bool Connect() = 0; virtual bool Connect() = 0;
/*! /*!
* \brief Send data to specified Receiver * \brief Send data to specified Receiver.
* \param data data buffer for sending * \param msg data message
* \param size data size for sending
* \param recv_id receiver's ID * \param recv_id receiver's ID
* \return bytes we sent * \return Status code
* > 0 : bytes we sent *
* - 1 : error * (1) The send is non-blocking. There is no guarantee that the message has been
* physically sent out when the function returns.
* (2) The communicator will assume the responsibility of the given message.
* (3) The API is multi-thread safe.
* (4) Messages sent to the same receiver are guaranteed to be received in the same order.
* There is no guarantee for messages sent to different receivers.
*/ */
virtual int64_t Send(const char* data, int64_t size, int recv_id) = 0; virtual STATUS Send(Message msg, int recv_id) = 0;
/*! /*!
* \brief Finalize Sender * \brief Finalize Sender
*
* Finalize() is not thread-safe and only one thread can invoke this API.
*/ */
virtual void Finalize() = 0; virtual void Finalize() = 0;
/*! /*!
* \brief Get data buffer * \brief Communicator type: 'socket', 'mpi', etc.
* \return buffer pointer
*/ */
virtual char* GetBuffer() = 0; virtual std::string Type() const = 0;
protected:
/*! /*!
* \brief Set data buffer * \brief Size of message queue
*/ */
virtual void SetBuffer(char* buffer) = 0; int64_t queue_size_;
}; };
/*! /*!
* \brief Network Receiver for DGL distributed training. * \brief Network Receiver for DGL distributed training.
* *
* Receiver is an abstract class that defines a set of APIs for receiving binary * Receiver is an abstract class that defines a set of APIs for receiving binary data
* data over network. It can be implemented by different underlying networking libraries * message over network. It can be implemented by different underlying networking
* such TCP socket and ZMQ. One Receiver can connect with multiple Senders, and it can receive * libraries such as TCP socket and MPI. One Receiver can connect with multiple Senders
* data from these Senders concurrently via multi-threading and message queue. * and it can receive data from multiple Senders concurrently.
*/ */
class Receiver { class Receiver {
public: public:
/*!
* \brief Receiver constructor
* \param queue_size size of message queue.
* Note that, the queue_size parameter is optional.
*/
explicit Receiver(int64_t queue_size = 0) {
if (queue_size < 0) {
LOG(FATAL) << "queue_size cannot be a negative number.";
}
queue_size_ = queue_size;
}
virtual ~Receiver() {} virtual ~Receiver() {}
/*! /*!
* \brief Wait all of the Senders to connect * \brief Wait for all the Senders to connect
* \param ip Receiver's IP address * \param addr Networking address, e.g., 'socket://127.0.0.1:50051', 'mpi://0'
* \param port Receiver's port
* \param num_sender total number of Senders * \param num_sender total number of Senders
* \param queue_size size of message queue * \return True for success and False for fail
* \return True for sucess and False for fail *
* Wait() is not thread-safe and only one thread can invoke this API.
*/
virtual bool Wait(const char* addr, int num_sender) = 0;
/*!
* \brief Recv data from Sender
* \param msg pointer of data message
* \param send_id which sender current msg comes from
* \return Status code
*
* (1) The Recv() API is blocking, which will not
* return until getting data from message queue.
* (2) The Recv() API is thread-safe.
* (3) Memory allocated by communicator but will not own it after the function returns.
*/ */
virtual bool Wait(const char* ip, int port, int num_sender, int queue_size) = 0; virtual STATUS Recv(Message* msg, int* send_id) = 0;
/*! /*!
* \brief Recv data from Sender (copy data from message queue) * \brief Recv data from a specified Sender
* \param dest data buffer of destination * \param msg pointer of data message
* \param max_size maximul size of data buffer * \param send_id sender's ID
* \return bytes we received * \return Status code
* > 0 : bytes we received *
* - 1 : error * (1) The RecvFrom() API is blocking, which will not
* return until getting data from message queue.
* (2) The RecvFrom() API is thread-safe.
* (3) Memory allocated by communicator but will not own it after the function returns.
*/ */
virtual int64_t Recv(char* dest, int64_t max_size) = 0; virtual STATUS RecvFrom(Message* msg, int send_id) = 0;
/*! /*!
* \brief Finalize Receiver * \brief Finalize Receiver
*
* Finalize() is not thread-safe and only one thread can invoke this API.
*/ */
virtual void Finalize() = 0; virtual void Finalize() = 0;
/*! /*!
* \brief Get data buffer * \brief Communicator type: 'socket', 'mpi', etc
* \return buffer pointer
*/ */
virtual char* GetBuffer() = 0; virtual std::string Type() const = 0;
protected:
/*! /*!
* \brief Set data buffer * \brief Size of message queue
*/ */
virtual void SetBuffer(char* buffer) = 0; int64_t queue_size_;
}; };
} // namespace network } // namespace network
......
...@@ -14,162 +14,75 @@ namespace network { ...@@ -14,162 +14,75 @@ namespace network {
using std::string; using std::string;
MessageQueue::MessageQueue(int64_t queue_size, int num_producers) { MessageQueue::MessageQueue(int64_t queue_size, int num_producers) {
CHECK_LT(0, queue_size); CHECK_GE(queue_size, 0);
try { CHECK_GE(num_producers, 0);
queue_ = new char[queue_size];
} catch(const std::bad_alloc&) {
LOG(FATAL) << "Not enough memory for message queue.";
}
memset(queue_, '\0', queue_size);
queue_size_ = queue_size; queue_size_ = queue_size;
free_size_ = queue_size; free_size_ = queue_size;
write_pointer_ = 0;
num_producers_ = num_producers; num_producers_ = num_producers;
} }
MessageQueue::~MessageQueue() { STATUS MessageQueue::Add(Message msg, bool is_blocking) {
std::lock_guard<std::mutex> lock(mutex_);
if (nullptr != queue_) {
delete [] queue_;
queue_ = nullptr;
}
}
int64_t MessageQueue::Add(const char* src, int64_t size, bool is_blocking) {
// check if message is too long to fit into the queue // check if message is too long to fit into the queue
if (size > queue_size_) { if (msg.size > queue_size_) {
LOG(ERROR) << "Message is larger than the queue."; LOG(WARNING) << "Message is larger than the queue.";
return -1; return MSG_GT_SIZE;
} }
if (size <= 0) { if (msg.size <= 0) {
LOG(ERROR) << "Message size (" << size << ") is negative or zero."; LOG(WARNING) << "Message size (" << msg.size << ") is negative or zero.";
return -1; return MSG_LE_ZERO;
} }
std::unique_lock<std::mutex> lock(mutex_); std::unique_lock<std::mutex> lock(mutex_);
if (finished_producers_.size() >= num_producers_) { if (finished_producers_.size() >= num_producers_) {
LOG(ERROR) << "Can't add to buffer when flag_no_more is set."; LOG(WARNING) << "Message queue is closed.";
return -1; return QUEUE_CLOSE;
} }
if (size > free_size_ && !is_blocking) { if (msg.size > free_size_ && !is_blocking) {
LOG(WARNING) << "Queue is full and message lost."; return QUEUE_FULL;
return 0;
} }
cond_not_full_.wait(lock, [&]() { cond_not_full_.wait(lock, [&]() {
return size <= free_size_; return msg.size <= free_size_;
}); });
// Write data into buffer: // Add data pointer to queue
// If there has enough space on tail of buffer, just append data queue_.push(msg);
// else, write till in the end of buffer and return to head of buffer free_size_ -= msg.size;
message_positions_.push(std::make_pair(write_pointer_, size));
free_size_ -= size;
if (write_pointer_ + size <= queue_size_) {
memcpy(&queue_[write_pointer_], src, size);
write_pointer_ += size;
if (write_pointer_ == queue_size_) {
write_pointer_ = 0;
}
} else {
int64_t size_partial = queue_size_ - write_pointer_;
memcpy(&queue_[write_pointer_], src, size_partial);
memcpy(queue_, &src[size_partial], size - size_partial);
write_pointer_ = size - size_partial;
}
// not empty signal // not empty signal
cond_not_empty_.notify_one(); cond_not_empty_.notify_one();
return size; return ADD_SUCCESS;
}
int64_t MessageQueue::Add(const string &src, bool is_blocking) {
return Add(src.data(), src.size(), is_blocking);
}
int64_t MessageQueue::Remove(char *dest, int64_t max_size, bool is_blocking) {
int64_t retval;
std::unique_lock<std::mutex> lock(mutex_);
if (message_positions_.empty()) {
if (!is_blocking) {
return 0;
}
if (finished_producers_.size() >= num_producers_) {
return 0;
}
}
cond_not_empty_.wait(lock, [this] {
return !message_positions_.empty() || exit_flag_.load();
});
if (finished_producers_.size() >= num_producers_) {
return 0;
}
MessagePosition & pos = message_positions_.front();
// check if message is too long
if (pos.second > max_size) {
LOG(ERROR) << "Message size exceeds limit, information lost.";
retval = -1;
} else {
// read from buffer:
// if this message stores in consecutive memory, just read
// else, read from buffer tail then return to the head
if (pos.first + pos.second <= queue_size_) {
memcpy(dest, &queue_[pos.first], pos.second);
} else {
int64_t size_partial = queue_size_ - pos.first;
memcpy(dest, &queue_[pos.first], size_partial);
memcpy(&dest[size_partial], queue_, pos.second - size_partial);
}
retval = pos.second;
}
free_size_ += pos.second;
message_positions_.pop();
cond_not_full_.notify_one();
return retval;
} }
int64_t MessageQueue::Remove(string *dest, bool is_blocking) { STATUS MessageQueue::Remove(Message* msg, bool is_blocking) {
int64_t retval;
std::unique_lock<std::mutex> lock(mutex_); std::unique_lock<std::mutex> lock(mutex_);
if (message_positions_.empty()) { if (queue_.empty()) {
if (!is_blocking) { if (!is_blocking) {
return 0; return QUEUE_EMPTY;
} }
if (finished_producers_.size() >= num_producers_) { if (finished_producers_.size() >= num_producers_) {
return 0; LOG(WARNING) << "Message queue is closed.";
return QUEUE_CLOSE;
} }
} }
cond_not_empty_.wait(lock, [this] { cond_not_empty_.wait(lock, [this] {
return !message_positions_.empty() || exit_flag_.load(); return !queue_.empty() || exit_flag_.load();
}); });
if (finished_producers_.size() >= num_producers_ && queue_.empty()) {
MessagePosition & pos = message_positions_.front(); LOG(WARNING) << "Message queue is closed.";
// read from buffer: return QUEUE_CLOSE;
// if this message stores in consecutive memory, just read
// else, read from buffer tail then return to the head
if (pos.first + pos.second <= queue_size_) {
dest->assign(&queue_[pos.first], pos.second);
} else {
int64_t size_partial = queue_size_ - pos.first;
dest->assign(&queue_[pos.first], size_partial);
dest->append(queue_, pos.second - size_partial);
} }
retval = pos.second;
free_size_ += pos.second;
message_positions_.pop();
Message & old_msg = queue_.front();
queue_.pop();
msg->data = old_msg.data;
msg->size = old_msg.size;
msg->deallocator = old_msg.deallocator;
free_size_ += old_msg.size;
cond_not_full_.notify_one(); cond_not_full_.notify_one();
return retval; return REMOVE_SUCCESS;
} }
void MessageQueue::Signal(int producer_id) { void MessageQueue::SignalFinished(int producer_id) {
std::lock_guard<std::mutex> lock(mutex_); std::lock_guard<std::mutex> lock(mutex_);
finished_producers_.insert(producer_id); finished_producers_.insert(producer_id);
// if all producers have finished, consumers should be // if all producers have finished, consumers should be
...@@ -180,9 +93,14 @@ void MessageQueue::Signal(int producer_id) { ...@@ -180,9 +93,14 @@ void MessageQueue::Signal(int producer_id) {
} }
} }
bool MessageQueue::Empty() const {
std::lock_guard<std::mutex> lock(mutex_);
return queue_.size() == 0;
}
bool MessageQueue::EmptyAndNoMoreAdd() const { bool MessageQueue::EmptyAndNoMoreAdd() const {
std::lock_guard<std::mutex> lock(mutex_); std::lock_guard<std::mutex> lock(mutex_);
return message_positions_.size() == 0 && return queue_.size() == 0 &&
finished_producers_.size() >= num_producers_; finished_producers_.size() >= num_producers_;
} }
......
...@@ -13,30 +13,82 @@ ...@@ -13,30 +13,82 @@
#include <mutex> #include <mutex>
#include <condition_variable> #include <condition_variable>
#include <atomic> #include <atomic>
#include <functional>
namespace dgl { namespace dgl {
namespace network { namespace network {
typedef int STATUS;
/*!
* \brief Status code of message queue
*/
#define ADD_SUCCESS 3400 // Add message successfully
#define MSG_GT_SIZE 3401 // Message size beyond queue size
#define MSG_LE_ZERO 3402 // Message size is not a positive number
#define QUEUE_CLOSE 3403 // Cannot add message when queue is closed
#define QUEUE_FULL 3404 // Cannot add message when queue is full
#define REMOVE_SUCCESS 3405 // Remove message successfully
#define QUEUE_EMPTY 3406 // Cannot remove when queue is empty
/*!
* \brief Message used by network communicator and message queue.
*/
struct Message {
/*!
* \brief Constructor
*/
Message() { }
/*!
* \brief Constructor
*/
Message(char* data_ptr, int64_t data_size)
: data(data_ptr), size(data_size) { }
/*!
* \brief message data
*/
char* data;
/*!
* \brief message size in bytes
*/
int64_t size;
/*!
* \brief aux_data pointer handler
*/
void* aux_handler;
/*!
* \brief user-defined deallocator, which can be nullptr
*/
std::function<void(Message*)> deallocator = nullptr;
};
/*! /*!
* \brief Message Queue for DGL distributed training. * \brief Free memory buffer of message
*/
inline void DefaultMessageDeleter(Message* msg) { delete [] msg->data; }
/*!
* \brief Message Queue for network communication.
*
* MessageQueue is FIFO queue that adopts producer/consumer model for data message.
* It supports one or more producer threads and one or more consumer threads.
* Producers invokes Add() to push data message into the queue, and consumers
* invokes Remove() to pop data message from queue. Add() and Remove() use two condition
* variables to synchronize producer threads and consumer threads. Each producer
* invokes SignalFinished(producer_id) to claim that it is about to finish, where
* producer_id is an integer uniquely identify a producer thread. This signaling mechanism
* prevents consumers from waiting after all producers have finished their jobs.
* *
* MessageQueue is a circle queue for using the ring-buffer in a * MessageQueue is thread-safe.
* producer/consumer model. It supports one or more producer
* threads and one or more consumer threads. Producers invokes Add()
* to push data elements into the queue, and consumers invokes
* Remove() to pop data elements. Add() and Remove() use two condition
* variables to synchronize producers and consumers. Each producer invokes
* Signal(producer_id) to claim that it is about to finish, where
* producer_id is an integer uniquely identify a producer thread. This
* signaling mechanism prevents consumers from waiting after all producers
* have finished their jobs.
* *
*/ */
class MessageQueue { class MessageQueue {
public: public:
/*! /*!
* \brief MessageQueue constructor * \brief MessageQueue constructor
* \param queue_size size of message queue * \param queue_size size (bytes) of message queue
* \param num_producers number of producers, use 1 by default * \param num_producers number of producers, use 1 by default
*/ */
MessageQueue(int64_t queue_size /* in bytes */, MessageQueue(int64_t queue_size /* in bytes */,
...@@ -45,59 +97,34 @@ class MessageQueue { ...@@ -45,59 +97,34 @@ class MessageQueue {
/*! /*!
* \brief MessageQueue deconstructor * \brief MessageQueue deconstructor
*/ */
~MessageQueue(); ~MessageQueue() {}
/*! /*!
* \brief Add data to the message queue * \brief Add message to the queue
* \param src The data pointer * \param msg data message
* \param size The size of data * \param is_blocking Blocking if cannot add, else return
* \param is_blocking Block function if cannot add, else return * \return Status code
* \return bytes added to the queue
* > 0 : size of message
* = 0 : no enough space for this message (when is_blocking = false)
* - 1 : error
*/ */
int64_t Add(const char* src, int64_t size, bool is_blocking = true); STATUS Add(Message msg, bool is_blocking = true);
/*!
* \brief Add data to the message queue
* \param src The data string
* \param is_blocking Block function if cannot add, else return
* \return bytes added to queue
* > 0 : size of message
* = 0 : no enough space for this message (when is_blocking = false)
* - 1 : error
*/
int64_t Add(const std::string& src, bool is_blocking = true);
/*! /*!
* \brief Remove message from the queue * \brief Remove message from the queue
* \param dest The destination data pointer * \param msg pointer of data msg
* \param max_size Maximal size of data * \param is_blocking Blocking if cannot remove, else return
* \param is_blocking Block function if cannot remove, else return * \return Status code
* \return bytes removed from queue
* > 0 : size of message
* = 0 : queue is empty
* - 1 : error
*/ */
int64_t Remove(char *dest, int64_t max_size, bool is_blocking = true); STATUS Remove(Message* msg, bool is_blocking = true);
/*! /*!
* \brief Remove message from the queue * \brief Signal that producer producer_id will no longer produce anything
* \param dest The destination data string * \param producer_id An integer uniquely to identify a producer thread
* \param is_blocking Block function if cannot remove, else return
* \return bytes removed from queue
* > 0 : size of message
* = 0 : queue is empty
* - 1 : error
*/ */
int64_t Remove(std::string *dest, bool is_blocking = true); void SignalFinished(int producer_id);
/*! /*!
* \brief Signal that producer producer_id will no longer produce anything * \return true if queue is empty.
* \param producer_id An integer uniquely to identify a producer thread
*/ */
void Signal(int producer_id); bool Empty() const;
/*! /*!
* \return true if queue is empty and all num_producers have signaled. * \return true if queue is empty and all num_producers have signaled.
...@@ -105,13 +132,10 @@ class MessageQueue { ...@@ -105,13 +132,10 @@ class MessageQueue {
bool EmptyAndNoMoreAdd() const; bool EmptyAndNoMoreAdd() const;
protected: protected:
typedef std::pair<int64_t /* message_start_position in queue_ */,
int64_t /* message_length */> MessagePosition;
/*! /*!
* \brief Pointer to the queue * \brief message queue
*/ */
char* queue_; std::queue<Message> queue_;
/*! /*!
* \brief Size of the queue in bytes * \brief Size of the queue in bytes
...@@ -123,24 +147,11 @@ class MessageQueue { ...@@ -123,24 +147,11 @@ class MessageQueue {
*/ */
int64_t free_size_; int64_t free_size_;
/*!
* \brief Location in queue_ for where to write the next element
* Note that we do not need read_pointer since all messages were indexed
* by message_postions_, and the first element in message_position_
* denotes where we should read
*/
int64_t write_pointer_;
/*! /*!
* \brief Used to check all producers will no longer produce anything * \brief Used to check all producers will no longer produce anything
*/ */
size_t num_producers_; size_t num_producers_;
/*!
* \brief Messages in the queue
*/
std::queue<MessagePosition> message_positions_;
/*! /*!
* \brief Store finished producer id * \brief Store finished producer id
*/ */
......
/*!
* Copyright (c) 2019 by Contributors
* \file serialize.cc
* \brief Serialization for DGL distributed training.
*/
#include "serialize.h"
#include <dmlc/logging.h>
#include <dgl/immutable_graph.h>
#include <cstring>
#include "../network.h"
namespace dgl {
namespace network {
const int kNumTensor = 7;  // We need to serialize 7 components (tensors) here
/*
 * Serialize a sampled subgraph (NodeFlow) into a flat binary buffer.
 *
 * Layout: for each of the 7 tensors we first write its payload size as an
 * int64_t, then the raw payload bytes. The write order below MUST match the
 * read order in DeserializeSampledSubgraph().
 *
 * \param data destination buffer; caller must provide at least the returned
 *             number of bytes of capacity (checked against kMaxBufferSize)
 * \param csr subgraph CSR structure
 * \param node_mapping node mapping of the NodeFlow index
 * \param edge_mapping edge mapping of the NodeFlow index
 * \param layer_offsets layer offsets of the NodeFlow index
 * \param flow_offsets flow offsets of the NodeFlow index
 * \return total number of bytes written
 */
int64_t SerializeSampledSubgraph(char* data,
                                 const CSRPtr csr,
                                 const IdArray& node_mapping,
                                 const IdArray& edge_mapping,
                                 const IdArray& layer_offsets,
                                 const IdArray& flow_offsets) {
  // Payload size (in bytes) of each component.
  const int64_t node_mapping_size = node_mapping->shape[0] * sizeof(dgl_id_t);
  const int64_t edge_mapping_size = edge_mapping->shape[0] * sizeof(dgl_id_t);
  const int64_t layer_offsets_size = layer_offsets->shape[0] * sizeof(dgl_id_t);
  const int64_t flow_offsets_size = flow_offsets->shape[0] * sizeof(dgl_id_t);
  const int64_t indptr_size = csr->indptr().GetSize();
  const int64_t indices_size = csr->indices().GetSize();
  const int64_t edge_ids_size = csr->edge_ids().GetSize();
  // Total = all payloads + one int64_t length prefix per tensor.
  const int64_t total_size =
      node_mapping_size + edge_mapping_size +
      layer_offsets_size + flow_offsets_size +
      indptr_size + indices_size + edge_ids_size +
      kNumTensor * sizeof(int64_t);
  if (total_size > kMaxBufferSize) {
    LOG(FATAL) << "Message size: (" << total_size
               << ") is larger than buffer size: ("
               << kMaxBufferSize << ")";
  }
  char* data_ptr = data;
  // Writes one length-prefixed tensor payload at the cursor and advances it.
  auto write_tensor = [&data_ptr](const void* src, int64_t size) {
    *(reinterpret_cast<int64_t*>(data_ptr)) = size;
    data_ptr += sizeof(int64_t);
    memcpy(data_ptr, src, size);
    data_ptr += size;
  };
  // Write order must mirror the read order in DeserializeSampledSubgraph().
  write_tensor(node_mapping->data, node_mapping_size);
  write_tensor(layer_offsets->data, layer_offsets_size);
  write_tensor(flow_offsets->data, flow_offsets_size);
  write_tensor(edge_mapping->data, edge_mapping_size);
  write_tensor(csr->indices()->data, indices_size);
  write_tensor(csr->edge_ids()->data, edge_ids_size);
  write_tensor(csr->indptr()->data, indptr_size);
  return total_size;
}
void DeserializeSampledSubgraph(char* data,
CSRPtr* csr,
IdArray* node_mapping,
IdArray* edge_mapping,
IdArray* layer_offsets,
IdArray* flow_offsets) {
// For each component, we first read its size at the
// begining of the buffer and then read its binary data
char* data_ptr = data;
// node_mapping
int64_t tensor_size = *(reinterpret_cast<int64_t*>(data_ptr));
int64_t num_vertices = tensor_size / sizeof(int64_t);
data_ptr += sizeof(int64_t);
*node_mapping = IdArray::Empty({static_cast<int64_t>(num_vertices)},
DLDataType{kDLInt, 64, 1}, DLContext{kDLCPU, 0});
dgl_id_t* node_map_data = static_cast<dgl_id_t*>((*node_mapping)->data);
memcpy(node_map_data, data_ptr, tensor_size);
data_ptr += tensor_size;
// layer offsets
tensor_size = *(reinterpret_cast<int64_t*>(data_ptr));
int64_t num_hops_add_one = tensor_size / sizeof(int64_t);
data_ptr += sizeof(int64_t);
*layer_offsets = IdArray::Empty({static_cast<int64_t>(num_hops_add_one)},
DLDataType{kDLInt, 64, 1}, DLContext{kDLCPU, 0});
dgl_id_t* layer_off_data = static_cast<dgl_id_t*>((*layer_offsets)->data);
memcpy(layer_off_data, data_ptr, tensor_size);
data_ptr += tensor_size;
// flow offsets
tensor_size = *(reinterpret_cast<int64_t*>(data_ptr));
int64_t num_hops = tensor_size / sizeof(int64_t);
data_ptr += sizeof(int64_t);
*flow_offsets = IdArray::Empty({static_cast<int64_t>(num_hops)},
DLDataType{kDLInt, 64, 1}, DLContext{kDLCPU, 0});
dgl_id_t* flow_off_data = static_cast<dgl_id_t*>((*flow_offsets)->data);
memcpy(flow_off_data, data_ptr, tensor_size);
data_ptr += tensor_size;
// edge_mapping
tensor_size = *(reinterpret_cast<int64_t*>(data_ptr));
int64_t num_edges = tensor_size / sizeof(int64_t);
data_ptr += sizeof(int64_t);
*edge_mapping = IdArray::Empty({static_cast<int64_t>(num_edges)},
DLDataType{kDLInt, 64, 1}, DLContext{kDLCPU, 0});
dgl_id_t* edge_mapping_data = static_cast<dgl_id_t*>((*edge_mapping)->data);
memcpy(edge_mapping_data, data_ptr, tensor_size);
data_ptr += tensor_size;
// Construct sub_csr_graph
// TODO(minjie): multigraph flag
*csr = CSRPtr(new CSR(num_vertices, num_edges, false));
// indices (CSR)
tensor_size = *(reinterpret_cast<int64_t*>(data_ptr));
data_ptr += sizeof(int64_t);
dgl_id_t* col_list_out = static_cast<dgl_id_t*>((*csr)->indices()->data);
memcpy(col_list_out, data_ptr, tensor_size);
data_ptr += tensor_size;
// edge_ids (CSR)
tensor_size = *(reinterpret_cast<int64_t*>(data_ptr));
data_ptr += sizeof(int64_t);
dgl_id_t* edge_ids = static_cast<dgl_id_t*>((*csr)->edge_ids()->data);
memcpy(edge_ids, data_ptr, tensor_size);
data_ptr += tensor_size;
// indptr (CSR)
tensor_size = *(reinterpret_cast<int64_t*>(data_ptr));
data_ptr += sizeof(int64_t);
dgl_id_t* indptr_out = static_cast<dgl_id_t*>((*csr)->indptr()->data);
memcpy(indptr_out, data_ptr, tensor_size);
data_ptr += tensor_size;
}
} // namespace network
} // namespace dgl
/*!
* Copyright (c) 2019 by Contributors
* \file serialize.h
* \brief Serialization for DGL distributed training.
*/
#ifndef DGL_GRAPH_NETWORK_SERIALIZE_H_
#define DGL_GRAPH_NETWORK_SERIALIZE_H_
#include <dgl/sampler.h>
#include <dgl/immutable_graph.h>
namespace dgl {
namespace network {
/*!
 * \brief Serialize a sampled subgraph (NodeFlow) into a binary buffer.
 *
 * Each tensor is written as an int64_t byte-length prefix followed by its
 * raw payload; the write order must match DeserializeSampledSubgraph().
 *
 * \param data pointer to the destination buffer; the caller must supply at
 *             least the returned number of bytes of capacity
 * \param csr subgraph CSR structure
 * \param node_mapping node mapping in the NodeFlow index
 * \param edge_mapping edge mapping in the NodeFlow index
 * \param layer_offsets layer offsets in the NodeFlow index
 * \param flow_offsets flow offsets in the NodeFlow index
 * \return the total size (in bytes) of the serialized binary data
 */
int64_t SerializeSampledSubgraph(char* data,
const CSRPtr csr,
const IdArray& node_mapping,
const IdArray& edge_mapping,
const IdArray& layer_offsets,
const IdArray& flow_offsets);
/*!
 * \brief Deserialize a sampled subgraph from binary data produced by
 *        SerializeSampledSubgraph().
 *
 * All parameters except \a data are outputs that this function allocates
 * and fills.
 *
 * \param data pointer to the source data buffer
 * \param csr output: newly constructed subgraph CSR
 * \param node_mapping output: node mapping in the NodeFlow index
 * \param edge_mapping output: edge mapping in the NodeFlow index
 * \param layer_offsets output: layer offsets in the NodeFlow index
 * \param flow_offsets output: flow offsets in the NodeFlow index
 */
void DeserializeSampledSubgraph(char* data,
CSRPtr* csr,
IdArray* node_mapping,
IdArray* edge_mapping,
IdArray* layer_offsets,
IdArray* flow_offsets);
// TODO(chao): we can add compression and decompression method here
} // namespace network
} // namespace dgl
#endif // DGL_GRAPH_NETWORK_SERIALIZE_H_
...@@ -5,9 +5,12 @@ ...@@ -5,9 +5,12 @@
*/ */
#include <dmlc/logging.h> #include <dmlc/logging.h>
#include <string.h>
#include <stdlib.h>
#include <time.h>
#include "socket_communicator.h" #include "socket_communicator.h"
#include "../../c_api_common.h" #include "../../c_api_common.h"
#include "../network.h"
#ifdef _WIN32 #ifdef _WIN32
#include <windows.h> #include <windows.h>
...@@ -18,34 +21,55 @@ ...@@ -18,34 +21,55 @@
namespace dgl { namespace dgl {
namespace network { namespace network {
const int kTimeOut = 10; // 10 minutes for socket timeout
const int kMaxConnection = 1024; // 1024 maximal socket connection
void SocketSender::AddReceiver(const char* ip, int port, int recv_id) { /////////////////////////////////////// SocketSender ///////////////////////////////////////////
dgl::network::Addr addr;
addr.ip_.assign(const_cast<char*>(ip));
addr.port_ = port; void SocketSender::AddReceiver(const char* addr, int recv_id) {
receiver_addr_map_[recv_id] = addr; CHECK_NOTNULL(addr);
if (recv_id < 0) {
LOG(FATAL) << "recv_id cannot be a negative number.";
}
std::vector<std::string> substring;
std::vector<std::string> ip_and_port;
SplitStringUsing(addr, "//", &substring);
// Check address format
if (substring[0] != "socket:" || substring.size() != 2) {
LOG(FATAL) << "Incorrect address format:" << addr
<< " Please provide right address format, "
<< "e.g, 'socket://127.0.0.1:50051'. ";
}
// Get IP and port
SplitStringUsing(substring[1], ":", &ip_and_port);
if (ip_and_port.size() != 2) {
LOG(FATAL) << "Incorrect address format:" << addr
<< " Please provide right address format, "
<< "e.g, 'socket://127.0.0.1:50051'. ";
}
IPAddr address;
address.ip = ip_and_port[0];
address.port = std::stoi(ip_and_port[1]);
receiver_addrs_[recv_id] = address;
msg_queue_[recv_id] = std::make_shared<MessageQueue>(queue_size_);
} }
bool SocketSender::Connect() { bool SocketSender::Connect() {
// Create N sockets for Receiver // Create N sockets for Receiver
for (const auto& r : receiver_addr_map_) { for (const auto& r : receiver_addrs_) {
int ID = r.first; int ID = r.first;
socket_map_[ID] = new TCPSocket(); sockets_[ID] = std::make_shared<TCPSocket>();
TCPSocket* client = socket_map_[ID]; TCPSocket* client_socket = sockets_[ID].get();
bool bo = false; bool bo = false;
int try_count = 0; int try_count = 0;
const char* ip = r.second.ip_.c_str(); const char* ip = r.second.ip.c_str();
int port = r.second.port_; int port = r.second.port;
while (bo == false && try_count < kMaxTryCount) { while (bo == false && try_count < kMaxTryCount) {
if (client->Connect(ip, port)) { if (client_socket->Connect(ip, port)) {
LOG(INFO) << "Connected to Receiver: " << ip << ":" << port; LOG(INFO) << "Connected to Receiver: " << ip << ":" << port;
bo = true; bo = true;
} else { } else {
LOG(ERROR) << "Cannot connect to Receiver: " << ip << ":" << port LOG(ERROR) << "Cannot connect to Receiver: " << ip << ":" << port
<< ", try again ..."; << ", try again ...";
bo = false;
try_count++; try_count++;
#ifdef _WIN32 #ifdef _WIN32
Sleep(1); Sleep(1);
...@@ -57,101 +81,200 @@ bool SocketSender::Connect() { ...@@ -57,101 +81,200 @@ bool SocketSender::Connect() {
if (bo == false) { if (bo == false) {
return bo; return bo;
} }
// Create a new thread for this socket connection
threads_[ID] = std::make_shared<std::thread>(
SendLoop,
client_socket,
msg_queue_[ID].get());
} }
return true; return true;
} }
int64_t SocketSender::Send(const char* data, int64_t size, int recv_id) { STATUS SocketSender::Send(Message msg, int recv_id) {
TCPSocket* client = socket_map_[recv_id]; CHECK_NOTNULL(msg.data);
// First sent the size of data CHECK_GT(msg.size, 0);
int64_t sent_bytes = 0; CHECK_GE(recv_id, 0);
while (static_cast<size_t>(sent_bytes) < sizeof(int64_t)) { // Add data message to message queue
int64_t max_len = sizeof(int64_t) - sent_bytes; STATUS code = msg_queue_[recv_id]->Add(msg);
int64_t tmp = client->Send( return code;
reinterpret_cast<char*>(&size)+sent_bytes,
max_len);
sent_bytes += tmp;
}
// Then send the data
sent_bytes = 0;
while (sent_bytes < size) {
int64_t max_len = size - sent_bytes;
int64_t tmp = client->Send(data+sent_bytes, max_len);
sent_bytes += tmp;
}
return size + sizeof(int64_t);
} }
void SocketSender::Finalize() { void SocketSender::Finalize() {
// Close all sockets // Send a signal to tell the msg_queue to finish its job
for (const auto& socket : socket_map_) { for (auto& mq : msg_queue_) {
TCPSocket* client = socket.second; // wait until queue is empty
if (client != nullptr) { while (mq.second->Empty() == false) {
client->Close(); #ifdef _WIN32
delete client; // just loop
client = nullptr; #else // !_WIN32
usleep(1000);
#endif // _WIN32
} }
int ID = mq.first;
mq.second->SignalFinished(ID);
}
// Block main thread until all socket-threads finish their jobs
for (auto& thread : threads_) {
thread.second->join();
}
// Clear all sockets
for (auto& socket : sockets_) {
socket.second->Close();
} }
delete buffer_;
} }
char* SocketSender::GetBuffer() { void SocketSender::SendLoop(TCPSocket* socket, MessageQueue* queue) {
return buffer_; CHECK_NOTNULL(socket);
CHECK_NOTNULL(queue);
bool exit = false;
while (!exit) {
Message msg;
STATUS code = queue->Remove(&msg);
if (code == QUEUE_CLOSE) {
msg.size = 0; // send an end-signal to receiver
exit = true;
}
// First send the size
// If exit == true, we will send zero size to reciever
int64_t sent_bytes = 0;
while (static_cast<size_t>(sent_bytes) < sizeof(int64_t)) {
int64_t max_len = sizeof(int64_t) - sent_bytes;
int64_t tmp = socket->Send(
reinterpret_cast<char*>(&msg.size)+sent_bytes,
max_len);
CHECK_NE(tmp, -1);
sent_bytes += tmp;
}
// Then send the data
sent_bytes = 0;
while (sent_bytes < msg.size) {
int64_t max_len = msg.size - sent_bytes;
int64_t tmp = socket->Send(msg.data+sent_bytes, max_len);
CHECK_NE(tmp, -1);
sent_bytes += tmp;
}
// delete msg
if (msg.deallocator != nullptr) {
msg.deallocator(&msg);
}
}
} }
void SocketSender::SetBuffer(char* buffer) { /////////////////////////////////////// SocketReceiver ///////////////////////////////////////////
buffer_ = buffer;
}
bool SocketReceiver::Wait(const char* ip, bool SocketReceiver::Wait(const char* addr, int num_sender) {
int port, CHECK_NOTNULL(addr);
int num_sender, CHECK_GT(num_sender, 0);
int queue_size) { std::vector<std::string> substring;
CHECK_GE(num_sender, 1); std::vector<std::string> ip_and_port;
CHECK_GT(queue_size, 0); SplitStringUsing(addr, "//", &substring);
// Initialize message queue // Check address format
if (substring[0] != "socket:" || substring.size() != 2) {
LOG(FATAL) << "Incorrect address format:" << addr
<< " Please provide right address format, "
<< "e.g, 'socket://127.0.0.1:50051'. ";
}
// Get IP and port
SplitStringUsing(substring[1], ":", &ip_and_port);
if (ip_and_port.size() != 2) {
LOG(FATAL) << "Incorrect address format:" << addr
<< " Please provide right address format, "
<< "e.g, 'socket://127.0.0.1:50051'. ";
}
std::string ip = ip_and_port[0];
int port = stoi(ip_and_port[1]);
// Initialize message queue for each connection
num_sender_ = num_sender; num_sender_ = num_sender;
queue_size_ = queue_size; for (int i = 0; i < num_sender_; ++i) {
queue_ = new MessageQueue(queue_size_, num_sender_); msg_queue_[i] = std::make_shared<MessageQueue>(queue_size_);
// Initialize socket, and socket_[0] is server socket }
socket_.resize(num_sender_+1); // Initialize socket and socket-thread
thread_.resize(num_sender_); server_socket_ = new TCPSocket();
socket_[0] = new TCPSocket(); server_socket_->SetTimeout(kTimeOut * 60 * 1000); // millsec
TCPSocket* server = socket_[0];
server->SetTimeout(kTimeOut * 60 * 1000); // millsec
// Bind socket // Bind socket
if (server->Bind(ip, port) == false) { if (server_socket_->Bind(ip.c_str(), port) == false) {
LOG(FATAL) << "Cannot bind to " << ip << ":" << port; LOG(FATAL) << "Cannot bind to " << ip << ":" << port;
return false;
} }
LOG(INFO) << "Bind to " << ip << ":" << port; LOG(INFO) << "Bind to " << ip << ":" << port;
// Listen // Listen
if (server->Listen(kMaxConnection) == false) { if (server_socket_->Listen(kMaxConnection) == false) {
LOG(FATAL) << "Cannot listen on " << ip << ":" << port; LOG(FATAL) << "Cannot listen on " << ip << ":" << port;
return false;
} }
LOG(INFO) << "Listen on " << ip << ":" << port << ", wait sender connect ..."; LOG(INFO) << "Listen on " << ip << ":" << port << ", wait sender connect ...";
// Accept all sender sockets // Accept all sender sockets
std::string accept_ip; std::string accept_ip;
int accept_port; int accept_port;
for (int i = 1; i <= num_sender_; ++i) { for (int i = 0; i < num_sender_; ++i) {
socket_[i] = new TCPSocket(); sockets_[i] = std::make_shared<TCPSocket>();
if (server->Accept(socket_[i], &accept_ip, &accept_port) == false) { if (server_socket_->Accept(sockets_[i].get(), &accept_ip, &accept_port) == false) {
LOG(FATAL) << "Error on accept socket."; LOG(WARNING) << "Error on accept socket.";
return false; return false;
} }
// create new thread for each socket // create new thread for each socket
thread_[i-1] = new std::thread(MsgHandler, socket_[i], queue_, i-1); threads_[i] = std::make_shared<std::thread>(
RecvLoop,
sockets_[i].get(),
msg_queue_[i].get());
LOG(INFO) << "Accept new sender: " << accept_ip << ":" << accept_port; LOG(INFO) << "Accept new sender: " << accept_ip << ":" << accept_port;
} }
return true; return true;
} }
void SocketReceiver::MsgHandler(TCPSocket* socket, MessageQueue* queue, int id) { STATUS SocketReceiver::Recv(Message* msg, int* send_id) {
char* buffer = new char[kMaxBufferSize]; // loop until get a message
for (;;) { for (;;) {
for (auto& mq : msg_queue_) {
*send_id = mq.first;
// We use non-block remove here
STATUS code = msg_queue_[*send_id]->Remove(msg, false);
if (code == QUEUE_EMPTY) {
continue; // jump to the next queue
} else {
return code;
}
}
}
}
STATUS SocketReceiver::RecvFrom(Message* msg, int send_id) {
// Get message from specified message queue
STATUS code = msg_queue_[send_id]->Remove(msg);
return code;
}
void SocketReceiver::Finalize() {
// Send a signal to tell the message queue to finish its job
for (auto& mq : msg_queue_) {
// wait until queue is empty
while (mq.second->Empty() == false) {
#ifdef _WIN32
// just loop
#else // !_WIN32
usleep(1000);
#endif // _WIN32
}
int ID = mq.first;
mq.second->SignalFinished(ID);
}
// Block main thread until all socket-threads finish their jobs
for (auto& thread : threads_) {
thread.second->join();
}
// Clear all sockets
for (auto& socket : sockets_) {
socket.second->Close();
}
}
void SocketReceiver::RecvLoop(TCPSocket* socket, MessageQueue* queue) {
CHECK_NOTNULL(socket);
CHECK_NOTNULL(queue);
for (;;) {
// If main thread had finished its job
if (queue->EmptyAndNoMoreAdd()) {
return; // exit loop thread
}
// First recv the size // First recv the size
int64_t received_bytes = 0; int64_t received_bytes = 0;
int64_t data_size = 0; int64_t data_size = 0;
...@@ -160,54 +283,36 @@ void SocketReceiver::MsgHandler(TCPSocket* socket, MessageQueue* queue, int id) ...@@ -160,54 +283,36 @@ void SocketReceiver::MsgHandler(TCPSocket* socket, MessageQueue* queue, int id)
int64_t tmp = socket->Receive( int64_t tmp = socket->Receive(
reinterpret_cast<char*>(&data_size)+received_bytes, reinterpret_cast<char*>(&data_size)+received_bytes,
max_len); max_len);
CHECK_NE(tmp, -1);
received_bytes += tmp; received_bytes += tmp;
} }
// Data_size ==-99 is a special signal to tell if (data_size < 0) {
// the MsgHandler to exit the loop LOG(FATAL) << "Recv data error (data_size: " << data_size << ")";
if (data_size <= 0) { } else if (data_size == 0) {
queue->Signal(id); // This is an end-signal sent by client
break; return;
} } else {
// Then recv the data char* buffer = nullptr;
received_bytes = 0; try {
while (received_bytes < data_size) { buffer = new char[data_size];
int64_t max_len = data_size - received_bytes; } catch(const std::bad_alloc&) {
int64_t tmp = socket->Receive(buffer+received_bytes, max_len); LOG(FATAL) << "Cannot allocate enough memory for message, "
received_bytes += tmp; << "(message size: " << data_size << ")";
} }
queue->Add(buffer, data_size); received_bytes = 0;
} while (received_bytes < data_size) {
delete [] buffer; int64_t max_len = data_size - received_bytes;
} int64_t tmp = socket->Receive(buffer+received_bytes, max_len);
CHECK_NE(tmp, -1);
int64_t SocketReceiver::Recv(char* dest, int64_t max_size) { received_bytes += tmp;
// Get message from message queue }
return queue_->Remove(dest, max_size); Message msg;
} msg.data = buffer;
msg.size = data_size;
void SocketReceiver::Finalize() { msg.deallocator = DefaultMessageDeleter;
for (int i = 0; i <= num_sender_; ++i) { queue->Add(msg);
if (i != 0) { // write -99 signal to exit loop
int64_t data_size = -99;
queue_->Add(
reinterpret_cast<char*>(&data_size),
sizeof(int64_t));
}
if (socket_[i] != nullptr) {
socket_[i]->Close();
delete socket_[i];
socket_[i] = nullptr;
} }
} }
delete buffer_;
}
char* SocketReceiver::GetBuffer() {
return buffer_;
}
void SocketReceiver::SetBuffer(char* buffer) {
buffer_ = buffer;
} }
} // namespace network } // namespace network
......
...@@ -14,136 +14,172 @@ ...@@ -14,136 +14,172 @@
#include "communicator.h" #include "communicator.h"
#include "msg_queue.h" #include "msg_queue.h"
#include "tcp_socket.h" #include "tcp_socket.h"
#include "common.h"
namespace dgl { namespace dgl {
namespace network { namespace network {
using dgl::network::MessageQueue; static int kMaxTryCount = 1024; // maximal connection: 1024
using dgl::network::TCPSocket; static int kTimeOut = 10; // 10 minutes for socket timeout
using dgl::network::Sender; static int kMaxConnection = 1024; // maximal connection: 1024
using dgl::network::Receiver;
/*! /*!
* \breif Networking address * \breif Networking address
*/ */
struct Addr { struct IPAddr {
std::string ip_; std::string ip;
int port_; int port;
}; };
/*! /*!
* \brief Network Sender for DGL distributed training. * \brief SocketSender for DGL distributed training.
* *
* Sender is an abstract class that defines a set of APIs for sending * SocketSender is the communicator implemented by tcp socket.
* binary data over network. It can be implemented by different underlying
* networking libraries such TCP socket and ZMQ. One Sender can connect to
* multiple receivers, and it can send data to specified receiver via receiver's ID.
*/ */
class SocketSender : public Sender { class SocketSender : public Sender {
public: public:
/*! /*!
* \brief Add receiver address and it's ID to the namebook * \brief Sender constructor
* \param ip receviver's IP address * \param queue_size size of message queue
* \param port receiver's port */
explicit SocketSender(int64_t queue_size) : Sender(queue_size) {}
/*!
* \brief Add receiver's address and ID to the sender's namebook
* \param addr Networking address, e.g., 'socket://127.0.0.1:50091', 'mpi://0'
* \param id receiver's ID * \param id receiver's ID
*
* AddReceiver() is not thread-safe and only one thread can invoke this API.
*/ */
void AddReceiver(const char* ip, int port, int recv_id); void AddReceiver(const char* addr, int recv_id);
/*! /*!
* \brief Connect with all the Receivers * \brief Connect with all the Receivers
* \return True for sucess and False for fail * \return True for success and False for fail
*
* Connect() is not thread-safe and only one thread can invoke this API.
*/ */
bool Connect(); bool Connect();
/*! /*!
* \brief Send data to specified Receiver * \brief Send data to specified Receiver. Actually pushing message to message queue.
* \param data data buffer for sending * \param msg data message
* \param size data size for sending
* \param recv_id receiver's ID * \param recv_id receiver's ID
* \return bytes we sent * \return Status code
* > 0 : bytes we sent *
* - 1 : error * (1) The send is non-blocking. There is no guarantee that the message has been
* physically sent out when the function returns.
* (2) The communicator will assume the responsibility of the given message.
* (3) The API is multi-thread safe.
* (4) Messages sent to the same receiver are guaranteed to be received in the same order.
* There is no guarantee for messages sent to different receivers.
*/ */
int64_t Send(const char* data, int64_t size, int recv_id); STATUS Send(Message msg, int recv_id);
/*! /*!
* \brief Finalize Sender * \brief Finalize SocketSender
*
* Finalize() is not thread-safe and only one thread can invoke this API.
*/ */
void Finalize(); void Finalize();
/*! /*!
* \brief Get data buffer * \brief Communicator type: 'socket'
* \return buffer pointer
*/ */
char* GetBuffer(); inline std::string Type() const { return std::string("socket"); }
private:
/*! /*!
* \brief Set data buffer * \brief socket for each connection of receiver
*/ */
void SetBuffer(char* buffer); std::unordered_map<int /* receiver ID */, std::shared_ptr<TCPSocket>> sockets_;
private:
/*! /*!
* \brief socket map * \brief receivers' address
*/ */
std::unordered_map<int, TCPSocket*> socket_map_; std::unordered_map<int /* receiver ID */, IPAddr> receiver_addrs_;
/*! /*!
* \brief receiver address map * \brief message queue for each socket connection
*/ */
std::unordered_map<int, Addr> receiver_addr_map_; std::unordered_map<int /* receiver ID */, std::shared_ptr<MessageQueue>> msg_queue_;
/*! /*!
* \brief data buffer * \brief Independent thread for each socket connection
*/ */
char* buffer_ = nullptr; std::unordered_map<int /* receiver ID */, std::shared_ptr<std::thread>> threads_;
/*!
* \brief Send-loop for each socket in per-thread
* \param socket TCPSocket for current connection
* \param queue message_queue for current connection
*
* Note that, the SendLoop will finish its loop-job and exit thread
* when the main thread invokes Signal() API on the message queue.
*/
static void SendLoop(TCPSocket* socket, MessageQueue* queue);
}; };
/*! /*!
* \brief Network Receiver for DGL distributed training. * \brief SocketReceiver for DGL distributed training.
* *
* Receiver is an abstract class that defines a set of APIs for receiving binary * SocketReceiver is the communicator implemented by tcp socket.
* data over network. It can be implemented by different underlying networking libraries
* such TCP socket and ZMQ. One Receiver can connect with multiple Senders, and it can receive
* data from these Senders concurrently via multi-threading and message queue.
*/ */
class SocketReceiver : public Receiver { class SocketReceiver : public Receiver {
public: public:
/*! /*!
* \brief Wait all of the Senders to connect * \brief Receiver constructor
* \param ip Receiver's IP address * \param queue_size size of message queue.
* \param port Receiver's port
* \param num_sender total number of Senders
* \param queue_size size of message queue
* \return True for sucess and False for fail
*/ */
bool Wait(const char* ip, int port, int num_sender, int queue_size); explicit SocketReceiver(int64_t queue_size) : Receiver(queue_size) {}
/*! /*!
* \brief Recv data from Sender (copy data from message queue) * \brief Wait for all the Senders to connect
* \param dest data buffer of destination * \param addr Networking address, e.g., 'socket://127.0.0.1:50051', 'mpi://0'
* \param max_size maximul size of data buffer * \param num_sender total number of Senders
* \return bytes we received * \return True for success and False for fail
* > 0 : bytes we received *
* - 1 : error * Wait() is not thread-safe and only one thread can invoke this API.
*/ */
int64_t Recv(char* dest, int64_t max_size); bool Wait(const char* addr, int num_sender);
/*! /*!
* \brief Finalize Receiver * \brief Recv data from Sender. Actually removing data from msg_queue.
* \param msg pointer of data message
* \param send_id which sender current msg comes from
* \return Status code
*
* (1) The Recv() API is blocking, which will not
* return until getting data from message queue.
* (2) The Recv() API is thread-safe.
* (3) Memory allocated by communicator but will not own it after the function returns.
*/ */
void Finalize(); STATUS Recv(Message* msg, int* send_id);
/*!
* \brief Recv data from a specified Sender. Actually removing data from msg_queue.
* \param msg pointer of data message
* \param send_id sender's ID
* \return Status code
*
* (1) The RecvFrom() API is blocking, which will not
* return until getting data from message queue.
* (2) The RecvFrom() API is thread-safe.
* (3) Memory allocated by communicator but will not own it after the function returns.
*/
STATUS RecvFrom(Message* msg, int send_id);
/*! /*!
* \brief Get data buffer * \brief Finalize SocketReceiver
* \return buffer pointer *
* Finalize() is not thread-safe and only one thread can invoke this API.
*/ */
char* GetBuffer(); void Finalize();
/*! /*!
* \brief Set data buffer * \brief Communicator type: 'socket'
*/ */
void SetBuffer(char* buffer); inline std::string Type() const { return std::string("socket"); }
private: private:
/*! /*!
...@@ -152,37 +188,34 @@ class SocketReceiver : public Receiver { ...@@ -152,37 +188,34 @@ class SocketReceiver : public Receiver {
int num_sender_; int num_sender_;
/*! /*!
* \brief maximal size of message queue * \brief server socket for listening connections
*/
int64_t queue_size_;
/*!
* \brief socket list
*/ */
std::vector<TCPSocket*> socket_; TCPSocket* server_socket_;
/*! /*!
* \brief Thread pool for socket connection * \brief socket for each client connections
*/ */
std::vector<std::thread*> thread_; std::unordered_map<int /* Sender (virutal) ID */, std::shared_ptr<TCPSocket>> sockets_;
/*! /*!
* \brief Message queue for communicator * \brief Message queue for each socket connection
*/ */
MessageQueue* queue_; std::unordered_map<int /* Sender (virtual) ID */, std::shared_ptr<MessageQueue>> msg_queue_;
/*! /*!
* \brief data buffer * \brief Independent thead for each socket connection
*/ */
char* buffer_ = nullptr; std::unordered_map<int /* Sender (virtual) ID */, std::shared_ptr<std::thread>> threads_;
/*! /*!
* \brief Process received message in independent threads * \brief Recv-loop for each socket in per-thread
* \param socket new accpeted socket * \param socket client socket
* \param queue message queue * \param queue message queue
* \param id producer_id *
* Note that, the RecvLoop will finish its loop-job and exit thread
* when the main thread invokes Signal() API on the message queue.
*/ */
static void MsgHandler(TCPSocket* socket, MessageQueue* queue, int id); static void RecvLoop(TCPSocket* socket, MessageQueue* queue);
}; };
} // namespace network } // namespace network
......
/*! /*!
* Copyright (c) 2019 by Contributors * Copyright (c) 2019 by Contributors
* \file msg_queue.cc * \file graph_index_test.cc
* \brief Message queue for DGL distributed training. * \brief Test GraphIndex
*/ */
#include <gtest/gtest.h> #include <gtest/gtest.h>
#include <dgl/graph.h> #include <dgl/graph.h>
......
...@@ -5,39 +5,100 @@ ...@@ -5,39 +5,100 @@
*/ */
#include <gtest/gtest.h> #include <gtest/gtest.h>
#include <string> #include <string>
#include <thread>
#include <vector>
#include "../src/graph/network/msg_queue.h" #include "../src/graph/network/msg_queue.h"
using std::string; using std::string;
using dgl::network::Message;
using dgl::network::MessageQueue; using dgl::network::MessageQueue;
TEST(MessageQueueTest, AddRemove) { TEST(MessageQueueTest, AddRemove) {
MessageQueue queue(5, 1); // size:5, num_of_producer:1 MessageQueue queue(5, 1); // size:5, num_of_producer:1
char buff[10]; // msg 1
queue.Add("111", 3); std::string str_1("111");
queue.Add("22", 2); Message msg_1 = {const_cast<char*>(str_1.data()), 3};
EXPECT_EQ(0, queue.Add("xxxx", 4, false)); // non-blocking add EXPECT_EQ(queue.Add(msg_1), ADD_SUCCESS);
queue.Remove(buff, 3); // msg 2
EXPECT_EQ(string(buff, 3), string("111")); std::string str_2("22");
queue.Remove(buff, 2); Message msg_2 = {const_cast<char*>(str_2.data()), 2};
EXPECT_EQ(string(buff, 2), string("22")); EXPECT_EQ(queue.Add(msg_2), ADD_SUCCESS);
queue.Add("33333", 5); // msg 3
queue.Remove(buff, 5); std::string str_3("xxxx");
EXPECT_EQ(string(buff, 5), string("33333")); Message msg_3 = {const_cast<char*>(str_3.data()), 4};
EXPECT_EQ(0, queue.Remove(buff, 10, false)); // non-blocking remove EXPECT_EQ(queue.Add(msg_3, false), QUEUE_FULL);
EXPECT_EQ(queue.Add("666666", 6), -1); // exceed buffer size // msg 4
queue.Add("55555", 5); Message msg_4;
EXPECT_EQ(queue.Remove(buff, 3), -1); // message too long EXPECT_EQ(queue.Remove(&msg_4), REMOVE_SUCCESS);
EXPECT_EQ(string(msg_4.data, msg_4.size), string("111"));
// msg 5
Message msg_5;
EXPECT_EQ(queue.Remove(&msg_5), REMOVE_SUCCESS);
EXPECT_EQ(string(msg_5.data, msg_5.size), string("22"));
// msg 6
std::string str_6("33333");
Message msg_6 = {const_cast<char*>(str_6.data()), 5};
EXPECT_EQ(queue.Add(msg_6), ADD_SUCCESS);
// msg 7
Message msg_7;
EXPECT_EQ(queue.Remove(&msg_7), REMOVE_SUCCESS);
EXPECT_EQ(string(msg_7.data, msg_7.size), string("33333"));
// msg 8
Message msg_8;
EXPECT_EQ(queue.Remove(&msg_8, false), QUEUE_EMPTY); // non-blocking remove
// msg 9
std::string str_9("666666");
Message msg_9 = {const_cast<char*>(str_9.data()), 6};
EXPECT_EQ(queue.Add(msg_9), MSG_GT_SIZE); // exceed queue size
// msg 10
std::string str_10("55555");
Message msg_10 = {const_cast<char*>(str_10.data()), 5};
EXPECT_EQ(queue.Add(msg_10), ADD_SUCCESS);
// msg 11
Message msg_11;
EXPECT_EQ(queue.Remove(&msg_11), REMOVE_SUCCESS);
} }
TEST(MessageQueueTest, EmptyAndNoMoreAdd) { TEST(MessageQueueTest, EmptyAndNoMoreAdd) {
MessageQueue queue(5, 2); // size:5, num_of_producer:2 MessageQueue queue(5, 2); // size:5, num_of_producer:2
char buff[10];
EXPECT_EQ(queue.EmptyAndNoMoreAdd(), false); EXPECT_EQ(queue.EmptyAndNoMoreAdd(), false);
queue.Signal(1); EXPECT_EQ(queue.Empty(), true);
queue.Signal(1); queue.SignalFinished(1);
queue.SignalFinished(1);
EXPECT_EQ(queue.EmptyAndNoMoreAdd(), false); EXPECT_EQ(queue.EmptyAndNoMoreAdd(), false);
queue.Signal(2); queue.SignalFinished(2);
EXPECT_EQ(queue.EmptyAndNoMoreAdd(), true);
}
const int kNumOfProducer = 100;
const int kNumOfMessage = 100;
std::string str_apple("apple");

// Producer routine, run on its own thread: enqueue kNumOfMessage
// messages whose 5-byte payload aliases the shared str_apple buffer,
// then mark producer `id` as finished so the queue can report
// EmptyAndNoMoreAdd() once drained.
void start_add(MessageQueue* queue, int id) {
  int pushed = 0;
  while (pushed < kNumOfMessage) {
    Message apple = {const_cast<char*>(str_apple.data()), 5};
    EXPECT_EQ(queue->Add(apple), ADD_SUCCESS);
    ++pushed;
  }
  queue->SignalFinished(id);
}
TEST(MessageQueueTest, MultiThread) {
MessageQueue queue(100000, kNumOfProducer);
EXPECT_EQ(queue.EmptyAndNoMoreAdd(), false);
EXPECT_EQ(queue.Empty(), true);
std::vector<std::thread*> thread_pool;
for (int i = 0; i < kNumOfProducer; ++i) {
thread_pool.push_back(new std::thread(start_add, &queue, i));
}
for (int i = 0; i < kNumOfProducer*kNumOfMessage; ++i) {
Message msg;
EXPECT_EQ(queue.Remove(&msg), REMOVE_SUCCESS);
EXPECT_EQ(string(msg.data, msg.size), string("apple"));
}
for (int i = 0; i < kNumOfProducer; ++i) {
thread_pool[i]->join();
}
EXPECT_EQ(queue.EmptyAndNoMoreAdd(), true); EXPECT_EQ(queue.EmptyAndNoMoreAdd(), true);
EXPECT_EQ(queue.Remove(buff, 5), 0);
} }
\ No newline at end of file
/*! /*!
* Copyright (c) 2019 by Contributors * Copyright (c) 2019 by Contributors
* \file msg_queue.cc * \file socket_communicator_test.cc
* \brief Message queue for DGL distributed training. * \brief Test SocketCommunicator
*/ */
#include <gtest/gtest.h> #include <gtest/gtest.h>
#include <string.h> #include <string.h>
#include <string> #include <string>
#include <thread>
#include <vector>
#include "../src/graph/network/msg_queue.h"
#include "../src/graph/network/socket_communicator.h" #include "../src/graph/network/socket_communicator.h"
using std::string; using std::string;
using dgl::network::SocketSender; using dgl::network::SocketSender;
using dgl::network::SocketReceiver; using dgl::network::SocketReceiver;
using dgl::network::Message;
using dgl::network::DefaultMessageDeleter;
void start_client(); const int64_t kQueueSize = 500 * 1024;
bool start_server();
#ifndef WIN32 #ifndef WIN32
#include <unistd.h> const int kNumSender = 3;
const int kNumReceiver = 3;
const int kNumMessage = 10;
const char* ip_addr[] = {
"socket://127.0.0.1:50091",
"socket://127.0.0.1:50092",
"socket://127.0.0.1:50093"
};
static void start_client();
static void start_server(int id);
TEST(SocketCommunicatorTest, SendAndRecv) { TEST(SocketCommunicatorTest, SendAndRecv) {
std::thread client_thread(start_client); // start 10 client
start_server(); std::vector<std::thread*> client_thread;
client_thread.join(); for (int i = 0; i < kNumSender; ++i) {
client_thread.push_back(new std::thread(start_client));
}
// start 10 server
std::vector<std::thread*> server_thread;
for (int i = 0; i < kNumReceiver; ++i) {
server_thread.push_back(new std::thread(start_server, i));
}
for (int i = 0; i < kNumSender; ++i) {
client_thread[i]->join();
}
for (int i = 0; i < kNumReceiver; ++i) {
server_thread[i]->join();
}
} }
#else // WIN32 void start_client() {
sleep(2); // wait server start
SocketSender sender(kQueueSize);
for (int i = 0; i < kNumReceiver; ++i) {
sender.AddReceiver(ip_addr[i], i);
}
sender.Connect();
for (int i = 0; i < kNumMessage; ++i) {
for (int n = 0; n < kNumReceiver; ++n) {
char* str_data = new char[9];
memcpy(str_data, "123456789", 9);
Message msg = {str_data, 9};
msg.deallocator = DefaultMessageDeleter;
EXPECT_EQ(sender.Send(msg, n), ADD_SUCCESS);
}
}
for (int i = 0; i < kNumMessage; ++i) {
for (int n = 0; n < kNumReceiver; ++n) {
char* str_data = new char[9];
memcpy(str_data, "123456789", 9);
Message msg = {str_data, 9};
msg.deallocator = DefaultMessageDeleter;
EXPECT_EQ(sender.Send(msg, n), ADD_SUCCESS);
}
}
sender.Finalize();
}
// Receiver-side routine for server `id`.
//
// Blocks until all kNumSender clients have connected, then drains two
// rounds of traffic: first kNumMessage messages per sender via the
// targeted RecvFrom(), then the remaining kNumSender*kNumMessage
// messages in arrival order via Recv(). Every payload must equal the
// 9-byte string "123456789"; each message's deallocator is invoked to
// release its buffer before the next receive.
void start_server(int id) {
  SocketReceiver receiver(kQueueSize);
  receiver.Wait(ip_addr[id], kNumSender);
  // Round 1: pull messages from each sender explicitly, in sender order.
  for (int round = 0; round < kNumMessage; ++round) {
    for (int sender = 0; sender < kNumSender; ++sender) {
      Message msg;
      EXPECT_EQ(receiver.RecvFrom(&msg, sender), REMOVE_SUCCESS);
      EXPECT_EQ(string(msg.data, msg.size), string("123456789"));
      msg.deallocator(&msg);
    }
  }
  // Round 2: pull the rest in whatever order they arrive, letting
  // Recv() report which sender produced each message.
  for (int count = 0; count < kNumSender * kNumMessage; ++count) {
    Message msg;
    int sender_id;
    EXPECT_EQ(receiver.Recv(&msg, &sender_id), REMOVE_SUCCESS);
    EXPECT_EQ(string(msg.data, msg.size), string("123456789"));
    msg.deallocator(&msg);
  }
  receiver.Finalize();
}
#else
#include <windows.h> #include <windows.h>
#include <winsock2.h> #include <winsock2.h>
...@@ -37,6 +115,9 @@ void sleep(int seconds) { ...@@ -37,6 +115,9 @@ void sleep(int seconds) {
Sleep(seconds * 1000); Sleep(seconds * 1000);
} }
static void start_client();
static bool start_server();
DWORD WINAPI _ClientThreadFunc(LPVOID param) { DWORD WINAPI _ClientThreadFunc(LPVOID param) {
start_client(); start_client();
return 0; return 0;
...@@ -70,24 +151,26 @@ TEST(SocketCommunicatorTest, SendAndRecv) { ...@@ -70,24 +151,26 @@ TEST(SocketCommunicatorTest, SendAndRecv) {
::WSACleanup(); ::WSACleanup();
} }
#endif // WIN32 static void start_client() {
void start_client() {
const char * msg = "123456789";
sleep(1); sleep(1);
SocketSender sender; SocketSender sender(kQueueSize);
sender.AddReceiver("127.0.0.1", 2049, 0); sender.AddReceiver("socket://127.0.0.1:8001", 0);
sender.Connect(); sender.Connect();
sender.Send(msg, 9, 0); char* str_data = new char[9];
memcpy(str_data, "123456789", 9);
Message msg = {str_data, 9};
msg.deallocator = DefaultMessageDeleter;
sender.Send(msg, 0);
sender.Finalize(); sender.Finalize();
} }
bool start_server() { static bool start_server() {
char serbuff[10]; SocketReceiver receiver(kQueueSize);
memset(serbuff, '\0', 10); receiver.Wait("socket://127.0.0.1:8001", 1);
SocketReceiver receiver; Message msg;
receiver.Wait("127.0.0.1", 2049, 1, 500 * 1024); EXPECT_EQ(receiver.RecvFrom(&msg, 0), REMOVE_SUCCESS);
receiver.Recv(serbuff, 9);
receiver.Finalize(); receiver.Finalize();
return string("123456789") == string(serbuff); return string("123456789") == string(msg.data, msg.size);
} }
#endif
/*!
* Copyright (c) 2019 by Contributors
* \file string_test.cc
* \brief Test String Common
*/
#include <gtest/gtest.h>
#include <string>
#include <vector>
#include "../src/graph/network/common.h"
using dgl::network::SplitStringUsing;
using dgl::network::StringPrintf;
using dgl::network::SStringPrintf;
using dgl::network::StringAppendF;
// A compound delimiter is a *set* of characters: any run of ' ' or '\t'
// separates tokens, and leading/trailing delimiters produce no empty
// tokens (the two expected tokens are exactly "apple" and "orange").
TEST(SplitStringTest, SplitStringUsingCompoundDelim) {
  std::string input(" apple \torange ");
  std::vector<std::string> tokens;
  SplitStringUsing(input, " \t", &tokens);
  EXPECT_EQ(tokens.size(), 2);
  EXPECT_EQ(tokens[0], std::string("apple"));
  EXPECT_EQ(tokens[1], std::string("orange"));
}
// Single-character delimiter: consecutive spaces collapse (no empty
// tokens between "apple" and "orange") and edge spaces are dropped.
TEST(SplitStringTest, testSplitStringUsingSingleDelim) {
  std::string input(" apple orange ");
  std::vector<std::string> tokens;
  SplitStringUsing(input, " ", &tokens);
  EXPECT_EQ(tokens.size(), 2);
  EXPECT_EQ(tokens[0], std::string("apple"));
  EXPECT_EQ(tokens[1], std::string("orange"));
}
// An input containing no delimiter at all comes back as exactly one
// token equal to the whole input string.
TEST(SplitStringTest, testSplitingNoDelimString) {
  std::string input("apple");
  std::vector<std::string> tokens;
  SplitStringUsing(input, " ", &tokens);
  EXPECT_EQ(tokens.size(), 1);
  EXPECT_EQ(tokens[0], std::string("apple"));
}
// Exercises the three printf-style helpers on a "%d" format:
// StringPrintf returns a fresh string, SStringPrintf overwrites its
// target, and StringAppendF appends to the existing contents.
TEST(StringPrintf, normal) {
  using std::string;
  EXPECT_EQ(StringPrintf("%d", 1), string("1"));
  string buf;
  SStringPrintf(&buf, "%d", 1);   // replaces buf's contents
  EXPECT_EQ(buf, string("1"));
  StringAppendF(&buf, "%d", 2);   // appends, yielding "12"
  EXPECT_EQ(buf, string("12"));
}
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment