Unverified Commit c3516f1a authored by Chao Ma, committed by GitHub

[Network] Refactoring Communicator (#679)

* Refactoring Communicator

* fix lint

* change non-const reference

* add header file

* use MemoryBuffer

* update PR

* fix bug on csr shape

* zero-copy msg_queue

* fix lint

* fix lint

* fix lint

* add header file

* fix windows build error

* fix windows build error

* update

* fix lint

* update

* fix lint

* fix lint

* add more test

* fix windows test

* update windows test

* update windows test

* update windows test

* update

* fix lint

* fix lint

* update

* update

* update

* update

* use STATUS code

* update test

* remove mem_cpy

* fix lint

* update

* finish

* ConstructNFTensor

* add test for deallocator

* update

* fix lint
parent fc9d30fa
......@@ -3,20 +3,20 @@ from ...network import _send_nodeflow, _recv_nodeflow
from ...network import _create_sender, _create_receiver
from ...network import _finalize_sender, _finalize_receiver
from ...network import _add_receiver_addr, _sender_connect
from ...network import _receiver_wait, _send_end_signal
from ...network import _receiver_wait, _send_sampler_end_signal
from multiprocessing import Pool
from abc import ABCMeta, abstractmethod
class SamplerPool(object):
"""SamplerPool is an abstract class, in which the worker method
"""SamplerPool is an abstract class, in which the worker() method
should be implemented by users. SamplerPool will fork() N (N = num_worker)
child processes, and each process will perform the worker() method independently.
Note that the fork() API will use shared memory for N processes and the OS will
perform copy-on-write only when developers write to that piece of memory. So forking N
processes and loading N copies of the graph will not increase the memory overhead.
Note that the fork() API uses shared memory for the N processes and the OS will
perform copy-on-write only when developers write to that piece of memory.
So forking N processes and loading N copies of the graph will not increase the memory overhead.
Users can use this class like this:
For example, users can use this class like this:
class MySamplerPool(SamplerPool):
......@@ -37,13 +37,13 @@ class SamplerPool(object):
Parameters
----------
num_worker : int
number of workers (number of child processes)
number of child processes
args : arguments
any arguments passed by user
"""
p = Pool()
for i in range(num_worker):
print("Start child process %d ..." % i)
print("Start child sampler process %d ..." % i)
p.apply_async(self.worker, args=(args,))
# Wait for all subprocesses to finish ...
p.close()
......@@ -51,7 +51,7 @@ class SamplerPool(object):
@abstractmethod
def worker(self, args):
"""User-defined function
"""User-defined function for worker
Parameters
----------
......@@ -63,28 +63,34 @@ class SamplerPool(object):
class SamplerSender(object):
"""SamplerSender for DGL distributed training.
Users use SamplerSender to send sampled subgraph (NodeFlow)
to remote SamplerReceiver. Note that a SamplerSender can connect
to multiple SamplerReceivers.
Users use SamplerSender to send sampled subgraphs (NodeFlow)
to remote SamplerReceivers. Note that a SamplerSender can currently
connect to multiple SamplerReceivers. The underlying implementation
sends different subgraphs to different SamplerReceivers in parallel
via multi-threading.
Parameters
----------
namebook : dict
address namebook of SamplerReceiver, where
key is receiver's ID and value is receiver's address, e.g.,
IP address namebook of SamplerReceiver, where the
key is receiver's ID (starting from 0) and value is receiver's address, e.g.,
{ 0:'168.12.23.45:50051',
1:'168.12.23.21:50051',
2:'168.12.46.12:50051' }
net_type : str
networking type, e.g., 'socket' (default) or 'mpi'.
"""
def __init__(self, namebook):
def __init__(self, namebook, net_type='socket'):
assert len(namebook) > 0, 'namebook cannot be empty.'
assert net_type in ('socket', 'mpi'), 'Unknown network type.'
self._namebook = namebook
self._sender = _create_sender()
self._sender = _create_sender(net_type)
for ID, addr in self._namebook.items():
vec = addr.split(':')
_add_receiver_addr(self._sender, vec[0], int(vec[1]), ID)
ip_port = addr.split(':')
assert len(ip_port) == 2, 'Incorrect format of IP address.'
_add_receiver_addr(self._sender, ip_port[0], int(ip_port[1]), ID)
_sender_connect(self._sender)
def __del__(self):
......@@ -93,36 +99,58 @@ class SamplerSender(object):
_finalize_sender(self._sender)
def send(self, nodeflow, recv_id):
"""Send sampled subgraph (NodeFlow) to remote trainer.
"""Send sampled subgraph (NodeFlow) to remote trainer. Note that,
the send() API is non-blocking and it returns immediately if the
underlying message queue is not full.
Parameters
----------
nodeflow : NodeFlow
sampled NodeFlow object
sampled NodeFlow
recv_id : int
receiver ID
receiver's ID
"""
assert recv_id >= 0, 'recv_id cannot be a negative number.'
_send_nodeflow(self._sender, nodeflow, recv_id)
def batch_send(self, nf_list, id_list):
"""Send a batch of subgraphs (Nodeflow) to remote trainer. Note that,
the batch_send() API is non-blocking and it returns immediately if the
underlying message queue is not full.
Parameters
----------
nf_list : list
a list of NodeFlow objects
id_list : list
a list of receiver IDs (recv_id)
"""
assert len(nf_list) > 0, 'nf_list cannot be empty.'
assert len(nf_list) == len(id_list), 'The length of nf_list must be equal to that of id_list.'
for i in range(len(nf_list)):
assert id_list[i] >= 0, 'recv_id cannot be a negative number.'
_send_nodeflow(self._sender, nf_list[i], id_list[i])
def signal(self, recv_id):
"""Whene samplling of each epoch is finished, users can
invoke this API to tell SamplerReceiver it has finished its job.
"""When the samplling of each epoch is finished, users can
invoke this API to tell SamplerReceiver that sampler has finished its job.
Parameters
----------
recv_id : int
receiver ID
receiver's ID
"""
_send_end_signal(self._sender, recv_id)
assert recv_id >= 0, 'recv_id cannot be a negative number.'
_send_sampler_end_signal(self._sender, recv_id)
class SamplerReceiver(object):
"""SamplerReceiver for DGL distributed training.
Users use SamplerReceiver to receive sampled subgraph (NodeFlow)
Users use SamplerReceiver to receive sampled subgraphs (NodeFlow)
from remote SamplerSender. Note that SamplerReceiver can receive messages
from multiple SamplerSenders concurrently, as specified by the num_sender parameter.
Note that only when all SamplerSenders have connected to SamplerReceiver can the
receiver start its job.
Only after all SamplerSenders have connected to SamplerReceiver successfully
can SamplerReceiver start its job.
Parameters
----------
......@@ -132,15 +160,20 @@ class SamplerReceiver(object):
address of SamplerReceiver, e.g., '127.0.0.1:50051'
num_sender : int
total number of SamplerSender
net_type : str
networking type, e.g., 'socket' (default) or 'mpi'.
"""
def __init__(self, graph, addr, num_sender):
def __init__(self, graph, addr, num_sender, net_type='socket'):
assert num_sender > 0, 'num_sender must be larger than zero.'
assert net_type in ('socket', 'mpi'), 'Unknown network type.'
self._graph = graph
self._addr = addr
self._num_sender = num_sender
self._tmp_count = 0
self._receiver = _create_receiver()
vec = self._addr.split(':')
_receiver_wait(self._receiver, vec[0], int(vec[1]), self._num_sender)
self._receiver = _create_receiver(net_type)
ip_port = addr.split(':')
assert len(ip_port) == 2, 'Incorrect format of IP address.'
_receiver_wait(self._receiver, ip_port[0], int(ip_port[1]), num_sender)
def __del__(self):
"""Finalize Receiver
......@@ -148,7 +181,7 @@ class SamplerReceiver(object):
_finalize_receiver(self._receiver)
def __iter__(self):
"""Iterator
"""Sampler iterator
"""
return self
......@@ -157,10 +190,10 @@ class SamplerReceiver(object):
"""
while True:
res = _recv_nodeflow(self._receiver, self._graph)
if isinstance(res, int):
if isinstance(res, int): # recv an end-signal
self._tmp_count += 1
if self._tmp_count == self._num_sender:
self._tmp_count = 0
raise StopIteration
else:
return res
return res # recv a nodeflow
......@@ -7,13 +7,30 @@ from . import utils
_init_api("dgl.network")
_CONTROL_NODEFLOW = 0
_CONTROL_END_SIGNAL = 1
def _create_sender():
################################ Common Network Components ##################################
def _create_sender(net_type):
"""Create a Sender communicator via C api
Parameters
----------
net_type : str
'socket' or 'mpi'
"""
return _CAPI_DGLSenderCreate()
assert net_type in ('socket', 'mpi'), 'Unknown network type.'
return _CAPI_DGLSenderCreate(net_type)
def _create_receiver(net_type):
"""Create a Receiver communicator via C api
Parameters
----------
net_type : str
'socket' or 'mpi'
"""
assert net_type in ('socket', 'mpi'), 'Unknown network type.'
return _CAPI_DGLReceiverCreate(net_type)
def _finalize_sender(sender):
"""Finalize Sender communicator
......@@ -25,6 +42,11 @@ def _finalize_sender(sender):
"""
_CAPI_DGLFinalizeSender(sender)
def _finalize_receiver(receiver):
"""Finalize Receiver Communicator
"""
_CAPI_DGLFinalizeReceiver(receiver)
def _add_receiver_addr(sender, ip_addr, port, recv_id):
"""Add Receiver IP address to namebook
......@@ -39,6 +61,7 @@ def _add_receiver_addr(sender, ip_addr, port, recv_id):
recv_id : int
Receiver ID
"""
assert recv_id >= 0, 'recv_id cannot be a negative number.'
_CAPI_DGLSenderAddReceiver(sender, ip_addr, int(port), int(recv_id))
def _sender_connect(sender):
......@@ -51,6 +74,27 @@ def _sender_connect(sender):
"""
_CAPI_DGLSenderConnect(sender)
def _receiver_wait(receiver, ip_addr, port, num_sender):
"""Wait all Sender to connect..
Parameters
----------
receiver : ctypes.c_void_p
C Receiver handle
ip_addr : str
IP address of Receiver
port : int
port of Receiver
num_sender : int
total number of Sender
"""
assert num_sender >= 0, 'num_sender cannot be a negative number.'
_CAPI_DGLReceiverWait(receiver, ip_addr, int(port), int(num_sender))
################################ Distributed Sampler Components ################################
def _send_nodeflow(sender, nodeflow, recv_id):
"""Send sampled subgraph (Nodeflow) to remote Receiver.
......@@ -63,12 +107,13 @@ def _send_nodeflow(sender, nodeflow, recv_id):
recv_id : int
Receiver ID
"""
assert recv_id >= 0, 'recv_id cannot be a negative number.'
gidx = nodeflow._graph
node_mapping = nodeflow._node_mapping.todgltensor()
edge_mapping = nodeflow._edge_mapping.todgltensor()
layers_offsets = utils.toindex(nodeflow._layer_offsets).todgltensor()
flows_offsets = utils.toindex(nodeflow._block_offsets).todgltensor()
_CAPI_SenderSendSubgraph(sender,
_CAPI_SenderSendNodeFlow(sender,
int(recv_id),
gidx,
node_mapping,
......@@ -76,7 +121,7 @@ def _send_nodeflow(sender, nodeflow, recv_id):
layers_offsets,
flows_offsets)
def _send_end_signal(sender, recv_id):
def _send_sampler_end_signal(sender, recv_id):
"""Send an epoch-end signal to remote Receiver.
Parameters
......@@ -86,33 +131,8 @@ def _send_end_signal(sender, recv_id):
recv_id : int
Receiver ID
"""
_CAPI_SenderSendEndSignal(sender, int(recv_id))
def _create_receiver():
"""Create a Receiver communicator via C api
"""
return _CAPI_DGLReceiverCreate()
def _finalize_receiver(receiver):
"""Finalize Receiver Communicator
"""
_CAPI_DGLFinalizeReceiver(receiver)
def _receiver_wait(receiver, ip_addr, port, num_sender):
"""Wait all Sender to connect..
Parameters
----------
receiver : ctypes.c_void_p
C Receiver handle
ip_addr : str
IP address of Receiver
port : int
port of Receiver
num_sender : int
total number of Sender
"""
_CAPI_DGLReceiverWait(receiver, ip_addr, int(port), int(num_sender))
assert recv_id >= 0, 'recv_id cannot be a negative number.'
_CAPI_SenderSendSamplerEndSignal(sender, int(recv_id))
def _recv_nodeflow(receiver, graph):
"""Receive sampled subgraph (NodeFlow) from remote sampler.
......@@ -126,15 +146,10 @@ def _recv_nodeflow(receiver, graph):
Returns
-------
NodeFlow
Sampled NodeFlow object
NodeFlow or an end-signal
"""
res = _CAPI_ReceiverRecvSubgraph(receiver)
res = _CAPI_ReceiverRecvNodeFlow(receiver)
if isinstance(res, int):
if res == _CONTROL_END_SIGNAL:
return _CONTROL_END_SIGNAL
else:
raise RuntimeError('Got unexpected control code {}'.format(res))
return res
else:
# res is of type List<NodeFlowObject>
return NodeFlow(graph, res[0])
return NodeFlow(graph, res)
......@@ -3,55 +3,120 @@
* \file graph/network.cc
* \brief DGL networking related APIs
*/
#include "./network.h"
#include <dgl/runtime/container.h>
#include <dgl/runtime/ndarray.h>
#include <dgl/packed_func_ext.h>
#include "./network.h"
#include <dgl/immutable_graph.h>
#include <dgl/nodeflow.h>
#include <unordered_map>
#include "./network/communicator.h"
#include "./network/socket_communicator.h"
#include "./network/serialize.h"
#include "../c_api_common.h"
#include "./network/msg_queue.h"
#include "./network/common.h"
using dgl::network::StringPrintf;
using namespace dgl::runtime;
namespace dgl {
namespace network {
// Wrapper for Send api
static void SendData(network::Sender* sender,
const char* data,
int64_t size,
int recv_id) {
int64_t send_size = sender->Send(data, size, recv_id);
if (send_size <= 0) {
LOG(FATAL) << "Send error (size: " << send_size << ")";
void MsgMeta::AddArray(const NDArray& array) {
// We first write the ndim to the data_shape_
data_shape_.push_back(static_cast<int64_t>(array->ndim));
// Then we write the data shape
for (int i = 0; i < array->ndim; ++i) {
data_shape_.push_back(array->shape[i]);
}
ndarray_count_++;
}
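// For example, after AddArray() on a 1-D tensor of length 5 and then a
// 2 x 3 tensor, ndarray_count_ is 2 and data_shape_ holds {1, 5, 2, 2, 3}
// (each array contributes its ndim followed by its dims).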
char* MsgMeta::Serialize(int64_t* size) {
char* buffer = nullptr;
int64_t buffer_size = 0;
buffer_size += sizeof(msg_type_);
if (ndarray_count_ != 0) {
buffer_size += sizeof(ndarray_count_);
buffer_size += sizeof(data_shape_.size());
buffer_size += sizeof(int64_t) * data_shape_.size();
}
buffer = new char[buffer_size];
char* pointer = buffer;
// Write msg_type_
*(reinterpret_cast<int*>(pointer)) = msg_type_;
pointer += sizeof(msg_type_);
if (ndarray_count_ != 0) {
// Write ndarray_count_
*(reinterpret_cast<int*>(pointer)) = ndarray_count_;
pointer += sizeof(ndarray_count_);
// Write size of data_shape_
*(reinterpret_cast<size_t*>(pointer)) = data_shape_.size();
pointer += sizeof(data_shape_.size());
// Write data of data_shape_
memcpy(pointer,
reinterpret_cast<char*>(data_shape_.data()),
sizeof(int64_t) * data_shape_.size());
}
*size = buffer_size;
return buffer;
}
// Wrapper for Recv api
static void RecvData(network::Receiver* receiver,
char* dest,
int64_t max_size) {
int64_t recv_size = receiver->Recv(dest, max_size);
if (recv_size <= 0) {
LOG(FATAL) << "Receive error (size: " << recv_size << ")";
void MsgMeta::Deserialize(char* buffer, int64_t size) {
int64_t data_size = 0;
// Read msg_type_
msg_type_ = *(reinterpret_cast<int*>(buffer));
buffer += sizeof(int);
data_size += sizeof(int);
if (data_size < size) {
// Read ndarray_count_
ndarray_count_ = *(reinterpret_cast<int*>(buffer));
buffer += sizeof(int);
data_size += sizeof(int);
// Read size of data_shape_
size_t count = *(reinterpret_cast<size_t*>(buffer));
buffer += sizeof(size_t);
data_size += sizeof(size_t);
data_shape_.resize(count);
// Read data of data_shape_
memcpy(data_shape_.data(), buffer,
count * sizeof(int64_t));
data_size += count * sizeof(int64_t);
}
CHECK_EQ(data_size, size);
}
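To make the wire format concrete: Serialize() and Deserialize() agree on the layout [int msg_type][int ndarray_count][size_t shape_len][int64_t shapes...], with the two count fields omitted when ndarray_count is zero. A minimal, self-contained round-trip sketch of that layout (it uses memcpy rather than the reinterpret_cast writes above, but produces the same bytes):
#include <cassert>
#include <cstdint>
#include <cstring>
#include <vector>
int main() {
  int msg_type = 0;                              // kNodeFlowMsg
  int ndarray_count = 2;
  std::vector<int64_t> shape = {1, 5, 2, 2, 3};  // a 1-D [5] array and a 2-D [2x3] array
  // Serialize: [msg_type][ndarray_count][shape.size()][shape...]
  size_t len = shape.size();
  std::vector<char> buf(sizeof(int) * 2 + sizeof(size_t) + sizeof(int64_t) * len);
  char* p = buf.data();
  std::memcpy(p, &msg_type, sizeof(int));       p += sizeof(int);
  std::memcpy(p, &ndarray_count, sizeof(int));  p += sizeof(int);
  std::memcpy(p, &len, sizeof(size_t));         p += sizeof(size_t);
  std::memcpy(p, shape.data(), sizeof(int64_t) * len);
  // Deserialize the same buffer and check the round trip
  p = buf.data();
  int type2 = 0, count2 = 0;
  size_t len2 = 0;
  std::memcpy(&type2, p, sizeof(int));          p += sizeof(int);
  std::memcpy(&count2, p, sizeof(int));         p += sizeof(int);
  std::memcpy(&len2, p, sizeof(size_t));        p += sizeof(size_t);
  std::vector<int64_t> shape2(len2);
  std::memcpy(shape2.data(), p, sizeof(int64_t) * len2);
  assert(type2 == msg_type && count2 == ndarray_count && shape2 == shape);
  return 0;
}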
////////////////////////////////// Basic Networking Components ////////////////////////////////
DGL_REGISTER_GLOBAL("network._CAPI_DGLSenderCreate")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
network::Sender* sender = new network::SocketSender();
try {
char* buffer = new char[kMaxBufferSize];
sender->SetBuffer(buffer);
} catch (const std::bad_alloc&) {
LOG(FATAL) << "Not enough memory for sender buffer: " << kMaxBufferSize;
std::string type = args[0];
network::Sender* sender = nullptr;
if (type == "socket") {
sender = new network::SocketSender(kQueueSize);
} else {
LOG(FATAL) << "Unknown communicator type: " << type;
}
CommunicatorHandle chandle = static_cast<CommunicatorHandle>(sender);
*rv = chandle;
});
DGL_REGISTER_GLOBAL("network._CAPI_DGLReceiverCreate")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
std::string type = args[0];
network::Receiver* receiver = nullptr;
if (type == "socket") {
receiver = new network::SocketReceiver(kQueueSize);
} else {
LOG(FATAL) << "Unknown communicator type: " << type;
}
CommunicatorHandle chandle = static_cast<CommunicatorHandle>(receiver);
*rv = chandle;
});
DGL_REGISTER_GLOBAL("network._CAPI_DGLFinalizeSender")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
CommunicatorHandle chandle = args[0];
......@@ -59,6 +124,13 @@ DGL_REGISTER_GLOBAL("network._CAPI_DGLFinalizeSender")
sender->Finalize();
});
DGL_REGISTER_GLOBAL("network._CAPI_DGLFinalizeReceiver")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
CommunicatorHandle chandle = args[0];
network::Receiver* receiver = static_cast<network::SocketReceiver*>(chandle);
receiver->Finalize();
});
DGL_REGISTER_GLOBAL("network._CAPI_DGLSenderAddReceiver")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
CommunicatorHandle chandle = args[0];
......@@ -66,7 +138,13 @@ DGL_REGISTER_GLOBAL("network._CAPI_DGLSenderAddReceiver")
int port = args[2];
int recv_id = args[3];
network::Sender* sender = static_cast<network::Sender*>(chandle);
sender->AddReceiver(ip.c_str(), port, recv_id);
std::string addr;
if (sender->Type() == "socket") {
addr = StringPrintf("socket://%s:%d", ip.c_str(), port);
} else {
LOG(FATAL) << "Unknown communicator type: " << sender->Type();
}
sender->AddReceiver(addr.c_str(), recv_id);
});
DGL_REGISTER_GLOBAL("network._CAPI_DGLSenderConnect")
......@@ -78,104 +156,218 @@ DGL_REGISTER_GLOBAL("network._CAPI_DGLSenderConnect")
}
});
DGL_REGISTER_GLOBAL("network._CAPI_SenderSendSubgraph")
DGL_REGISTER_GLOBAL("network._CAPI_DGLReceiverWait")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
CommunicatorHandle chandle = args[0];
std::string ip = args[1];
int port = args[2];
int num_sender = args[3];
network::Receiver* receiver = static_cast<network::SocketReceiver*>(chandle);
std::string addr;
if (receiver->Type() == "socket") {
addr = StringPrintf("socket://%s:%d", ip.c_str(), port);
} else {
LOG(FATAL) << "Unknown communicator type: " << receiver->Type();
}
if (receiver->Wait(addr.c_str(), num_sender) == false) {
LOG(FATAL) << "Wait sender socket failed.";
}
});
////////////////////////// Distributed Sampler Components ////////////////////////////////
DGL_REGISTER_GLOBAL("network._CAPI_SenderSendNodeFlow")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
CommunicatorHandle chandle = args[0];
int recv_id = args[1];
// TODO(minjie): could simply use NodeFlow nf = args[2];
GraphRef g = args[2];
const IdArray node_mapping = args[3];
const IdArray edge_mapping = args[4];
const IdArray layer_offsets = args[5];
const IdArray flow_offsets = args[6];
NDArray node_mapping = args[3];
NDArray edge_mapping = args[4];
NDArray layer_offsets = args[5];
NDArray flow_offsets = args[6];
auto ptr = std::dynamic_pointer_cast<ImmutableGraph>(g.sptr());
CHECK(ptr) << "only immutable graph is allowed in send/recv";
network::Sender* sender = static_cast<network::Sender*>(chandle);
auto csr = ptr->GetInCSR();
// Write control message
char* buffer = sender->GetBuffer();
*buffer = CONTROL_NODEFLOW;
// Serialize nodeflow to data buffer
int64_t data_size = network::SerializeSampledSubgraph(
buffer+sizeof(CONTROL_NODEFLOW),
csr,
node_mapping,
edge_mapping,
layer_offsets,
flow_offsets);
CHECK_GT(data_size, 0);
data_size += sizeof(CONTROL_NODEFLOW);
// Send msg via network
SendData(sender, buffer, data_size, recv_id);
// Create a message for the meta data of ndarray
NDArray indptr = csr->indptr();
NDArray indice = csr->indices();
NDArray edge_ids = csr->edge_ids();
MsgMeta msg(kNodeFlowMsg);
msg.AddArray(node_mapping);
msg.AddArray(edge_mapping);
msg.AddArray(layer_offsets);
msg.AddArray(flow_offsets);
msg.AddArray(indptr);
msg.AddArray(indice);
msg.AddArray(edge_ids);
// send meta message
int64_t size = 0;
char* data = msg.Serialize(&size);
network::Sender* sender = static_cast<network::Sender*>(chandle);
Message send_msg;
send_msg.data = data;
send_msg.size = size;
send_msg.deallocator = DefaultMessageDeleter;
CHECK_NE(sender->Send(send_msg, recv_id), -1);
// send node_mapping
Message node_mapping_msg;
node_mapping_msg.data = static_cast<char*>(node_mapping->data);
node_mapping_msg.size = node_mapping.GetSize();
node_mapping_msg.aux_handler = &node_mapping;
node_mapping_msg.deallocator = NDArrayDeleter;
CHECK_NE(sender->Send(node_mapping_msg, recv_id), -1);
// send edge_mapping
Message edge_mapping_msg;
edge_mapping_msg.data = static_cast<char*>(edge_mapping->data);
edge_mapping_msg.size = edge_mapping.GetSize();
edge_mapping_msg.aux_handler = &edge_mapping;
edge_mapping_msg.deallocator = NDArrayDeleter;
CHECK_NE(sender->Send(edge_mapping_msg, recv_id), -1);
// send layer_offsets
Message layer_offsets_msg;
layer_offsets_msg.data = static_cast<char*>(layer_offsets->data);
layer_offsets_msg.size = layer_offsets.GetSize();
layer_offsets_msg.aux_handler = &layer_offsets;
layer_offsets_msg.deallocator = NDArrayDeleter;
CHECK_NE(sender->Send(layer_offsets_msg, recv_id), -1);
// send flow_offsets
Message flow_offsets_msg;
flow_offsets_msg.data = static_cast<char*>(flow_offsets->data);
flow_offsets_msg.size = flow_offsets.GetSize();
flow_offsets_msg.aux_handler = &flow_offsets;
flow_offsets_msg.deallocator = NDArrayDeleter;
CHECK_NE(sender->Send(flow_offsets_msg, recv_id), -1);
// send csr->indptr
Message indptr_msg;
indptr_msg.data = static_cast<char*>(indptr->data);
indptr_msg.size = indptr.GetSize();
indptr_msg.aux_handler = &indptr;
indptr_msg.deallocator = NDArrayDeleter;
CHECK_NE(sender->Send(indptr_msg, recv_id), -1);
// send csr->indices
Message indices_msg;
indices_msg.data = static_cast<char*>(indice->data);
indices_msg.size = indice.GetSize();
indices_msg.aux_handler = &indice;
indices_msg.deallocator = NDArrayDeleter;
CHECK_NE(sender->Send(indices_msg, recv_id), -1);
// send csr->edge_ids
Message edge_ids_msg;
edge_ids_msg.data = static_cast<char*>(edge_ids->data);
edge_ids_msg.size = edge_ids.GetSize();
edge_ids_msg.aux_handler = &edge_ids;
edge_ids_msg.deallocator = NDArrayDeleter;
CHECK_NE(sender->Send(edge_ids_msg, recv_id), -1);
});
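The seven tensor sends above repeat one zero-copy pattern: wrap the NDArray's buffer in a Message, stash an NDArray handle in aux_handler so the buffer stays alive, and let NDArrayDeleter release the handle once the communicator is done. A condensed sketch of that pattern (the SendNDArray helper is hypothetical, not part of this PR; the handle is heap-allocated so the deleter's delete operates on a valid object):
// Hypothetical helper: send one NDArray without copying its buffer.
static void SendNDArray(network::Sender* sender, NDArray array, int recv_id) {
  NDArray* holder = new NDArray(array);     // keeps the buffer alive via refcounting
  Message msg;
  msg.data = static_cast<char*>((*holder)->data);
  msg.size = holder->GetSize();
  msg.aux_handler = holder;                 // released by NDArrayDeleter after the send
  msg.deallocator = NDArrayDeleter;
  CHECK_NE(sender->Send(msg, recv_id), -1);
}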
DGL_REGISTER_GLOBAL("network._CAPI_SenderSendEndSignal")
DGL_REGISTER_GLOBAL("network._CAPI_SenderSendSamplerEndSignal")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
CommunicatorHandle chandle = args[0];
int recv_id = args[1];
MsgMeta msg(kEndMsg);
int64_t size = 0;
char* data = msg.Serialize(&size);
network::Sender* sender = static_cast<network::Sender*>(chandle);
char* buffer = sender->GetBuffer();
*buffer = CONTROL_END_SIGNAL;
// Send msg via network
SendData(sender, buffer, sizeof(CONTROL_END_SIGNAL), recv_id);
});
DGL_REGISTER_GLOBAL("network._CAPI_DGLReceiverCreate")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
network::Receiver* receiver = new network::SocketReceiver();
try {
char* buffer = new char[kMaxBufferSize];
receiver->SetBuffer(buffer);
} catch (const std::bad_alloc&) {
LOG(FATAL) << "Not enough memory for receiver buffer: " << kMaxBufferSize;
}
CommunicatorHandle chandle = static_cast<CommunicatorHandle>(receiver);
*rv = chandle;
Message send_msg = {data, size};
send_msg.deallocator = DefaultMessageDeleter;
CHECK_NE(sender->Send(send_msg, recv_id), -1);
});
DGL_REGISTER_GLOBAL("network._CAPI_DGLFinalizeReceiver")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
CommunicatorHandle chandle = args[0];
network::Receiver* receiver = static_cast<network::SocketReceiver*>(chandle);
receiver->Finalize();
});
DGL_REGISTER_GLOBAL("network._CAPI_DGLReceiverWait")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
CommunicatorHandle chandle = args[0];
std::string ip = args[1];
int port = args[2];
int num_sender = args[3];
network::Receiver* receiver = static_cast<network::SocketReceiver*>(chandle);
receiver->Wait(ip.c_str(), port, num_sender, kQueueSize);
});
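// Wrap a raw buffer received from the network in a 1-D int64 CPU DLTensor so it
// can be converted into an NDArray via NDArray::FromDLPack() without copying.
// The shape array is heap-allocated here and must outlive the resulting tensor.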
static void ConstructNFTensor(DLTensor *tensor, char* data, int64_t shape_0) {
tensor->data = data;
tensor->ctx = DLContext{kDLCPU, 0};
tensor->ndim = 1;
tensor->dtype = DLDataType{kDLInt, 64, 1};
tensor->shape = new int64_t[1];
tensor->shape[0] = shape_0;
tensor->byte_offset = 0;
}
DGL_REGISTER_GLOBAL("network._CAPI_ReceiverRecvSubgraph")
DGL_REGISTER_GLOBAL("network._CAPI_ReceiverRecvNodeFlow")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
CommunicatorHandle chandle = args[0];
network::Receiver* receiver = static_cast<network::SocketReceiver*>(chandle);
// Recv data from network
char* buffer = receiver->GetBuffer();
RecvData(receiver, buffer, kMaxBufferSize);
int control = *buffer;
if (control == CONTROL_NODEFLOW) {
int send_id = 0;
Message recv_msg;
receiver->Recv(&recv_msg, &send_id);
MsgMeta msg(recv_msg.data, recv_msg.size);
recv_msg.deallocator(&recv_msg);
if (msg.msg_type() == kNodeFlowMsg) {
CHECK_EQ(msg.ndarray_count() * 2, msg.data_shape_.size());
NodeFlow nf = NodeFlow::Create();
CSRPtr csr;
// Deserialize nodeflow from recv_data_buffer
network::DeserializeSampledSubgraph(buffer+sizeof(CONTROL_NODEFLOW),
&(csr),
&(nf->node_mapping),
&(nf->edge_mapping),
&(nf->layer_offsets),
&(nf->flow_offsets));
// node_mapping
Message array_0;
CHECK_NE(receiver->RecvFrom(&array_0, send_id), -1);
CHECK_EQ(msg.data_shape_[0], 1);
DLTensor node_mapping_tensor;
ConstructNFTensor(&node_mapping_tensor, array_0.data, msg.data_shape_[1]);
DLManagedTensor *node_mapping_managed_tensor = new DLManagedTensor();
node_mapping_managed_tensor->dl_tensor = node_mapping_tensor;
nf->node_mapping = NDArray::FromDLPack(node_mapping_managed_tensor);
// edge_mapping
Message array_1;
CHECK_NE(receiver->RecvFrom(&array_1, send_id), -1);
CHECK_EQ(msg.data_shape_[2], 1);
DLTensor edge_mapping_tensor;
ConstructNFTensor(&edge_mapping_tensor, array_1.data, msg.data_shape_[3]);
DLManagedTensor *edge_mapping_managed_tensor = new DLManagedTensor();
edge_mapping_managed_tensor->dl_tensor = edge_mapping_tensor;
nf->edge_mapping = NDArray::FromDLPack(edge_mapping_managed_tensor);
// layer_offsets
Message array_2;
CHECK_NE(receiver->RecvFrom(&array_2, send_id), -1);
CHECK_EQ(msg.data_shape_[4], 1);
DLTensor layer_offsets_tensor;
ConstructNFTensor(&layer_offsets_tensor, array_2.data, msg.data_shape_[5]);
DLManagedTensor *layer_offsets_managed_tensor = new DLManagedTensor();
layer_offsets_managed_tensor->dl_tensor = layer_offsets_tensor;
nf->layer_offsets = NDArray::FromDLPack(layer_offsets_managed_tensor);
// flow_offsets
Message array_3;
CHECK_NE(receiver->RecvFrom(&array_3, send_id), -1);
CHECK_EQ(msg.data_shape_[6], 1);
DLTensor flow_offsets_tensor;
ConstructNFTensor(&flow_offsets_tensor, array_3.data, msg.data_shape_[7]);
DLManagedTensor *flow_offsets_managed_tensor = new DLManagedTensor();
flow_offsets_managed_tensor->dl_tensor = flow_offsets_tensor;
nf->flow_offsets = NDArray::FromDLPack(flow_offsets_managed_tensor);
// CSR indptr
Message array_4;
CHECK_NE(receiver->RecvFrom(&array_4, send_id), -1);
CHECK_EQ(msg.data_shape_[8], 1);
DLTensor indptr_tensor;
ConstructNFTensor(&indptr_tensor, array_4.data, msg.data_shape_[9]);
DLManagedTensor *indptr_managed_tensor = new DLManagedTensor();
indptr_managed_tensor->dl_tensor = indptr_tensor;
NDArray indptr = NDArray::FromDLPack(indptr_managed_tensor);
// CSR indice
Message array_5;
CHECK_NE(receiver->RecvFrom(&array_5, send_id), -1);
CHECK_EQ(msg.data_shape_[10], 1);
DLTensor indice_tensor;
ConstructNFTensor(&indice_tensor, array_5.data, msg.data_shape_[11]);
DLManagedTensor *indice_managed_tensor = new DLManagedTensor();
indice_managed_tensor->dl_tensor = indice_tensor;
NDArray indice = NDArray::FromDLPack(indice_managed_tensor);
// CSR edge_ids
Message array_6;
CHECK_NE(receiver->RecvFrom(&array_6, send_id), -1);
CHECK_EQ(msg.data_shape_[12], 1);
DLTensor edge_id_tensor;
ConstructNFTensor(&edge_id_tensor, array_6.data, msg.data_shape_[13]);
DLManagedTensor *edge_id_managed_tensor = new DLManagedTensor();
edge_id_managed_tensor->dl_tensor = edge_id_tensor;
NDArray edge_ids = NDArray::FromDLPack(edge_id_managed_tensor);
// Create CSR
CSRPtr csr(new CSR(indptr, indice, edge_ids));
nf->graph = GraphPtr(new ImmutableGraph(csr, nullptr));
List<NodeFlow> subgs;
subgs.push_back(nf);
*rv = subgs;
} else if (control == CONTROL_END_SIGNAL) {
*rv = CONTROL_END_SIGNAL;
*rv = nf;
} else if (msg.msg_type() == kEndMsg) {
*rv = msg.msg_type();
} else {
LOG(FATAL) << "Unknow control number: " << control;
LOG(FATAL) << "Unknown message type: " << msg.msg_type();
}
});
......
......@@ -7,25 +7,117 @@
#define DGL_GRAPH_NETWORK_H_
#include <dmlc/logging.h>
#include <dgl/runtime/ndarray.h>
#include <string.h>
#include <vector>
#include "../c_api_common.h"
#include "./network/msg_queue.h"
using dgl::runtime::NDArray;
namespace dgl {
namespace network {
#define IS_SENDER true
#define IS_RECEIVER false
// Max size of message queue for communicator is 200 MB
// TODO(chao): Make this number configurable
const int64_t kQueueSize = 200 * 1024 * 1024;
/*!
* \brief Deallocator that releases the NDArray handle stored in a message's aux_handler
*/
inline void NDArrayDeleter(Message* msg) {
delete reinterpret_cast<NDArray*>(msg->aux_handler);
}
/*!
* \brief Message type for DGL distributed training
*/
enum MessageType {
/*!
* \brief Message for send/recv NodeFlow
*/
kNodeFlowMsg = 0,
/*!
* \brief Message for end-signal
*/
kEndMsg = 1
};
/*!
* \brief Meta data for communicator message
*/
class MsgMeta {
public:
/*!
* \brief MsgMeta constructor.
* \param msg_type type of message
*/
explicit MsgMeta(int msg_type)
: msg_type_(msg_type), ndarray_count_(0) {}
/*!
* \brief Construct MsgMeta from binary data buffer.
* \param buffer data buffer
* \param size data size
*/
MsgMeta(char* buffer, int64_t size) {
CHECK_NOTNULL(buffer);
this->Deserialize(buffer, size);
}
/*!
* \return message type
*/
inline int msg_type() const {
return msg_type_;
}
/*!
* \return count of ndarray
*/
inline int ndarray_count() const {
return ndarray_count_;
}
/*!
* \brief Add NDArray meta data to MsgMeta
* \param array DGL NDArray
*/
void AddArray(const NDArray& array);
/*!
* \brief Serialize MsgMeta to data buffer
* \param size size of serialized message
* \return pointer of data buffer
*/
char* Serialize(int64_t* size);
/*!
* \brief Deserialize MsgMeta from data buffer
* \param buffer data buffer
* \param size size of data buffer
*/
void Deserialize(char* buffer, int64_t size);
/*!
* \brief type of message
*/
int msg_type_;
// TODO(chao): make these numbers configurable
/*!
* \brief count of ndarrays in MsgMeta
*/
int ndarray_count_;
// Each single message cannot be larger than 300 MB
const int64_t kMaxBufferSize = 300 * 1024 * 1024;
// Size of message queue is 1 GB
const int64_t kQueueSize = 1024 * 1024 * 1024;
// Maximum number of connection retries
const int kMaxTryCount = 500;
/*!
* \brief We first write the ndim to data_shape_
* and then write the data shape.
*/
std::vector<int64_t> data_shape_;
};
// Control number
const int CONTROL_NODEFLOW = 0;
const int CONTROL_END_SIGNAL = 1;
} // namespace network
} // namespace dgl
......
/*!
* Copyright (c) 2019 by Contributors
* \file common.cc
* \brief This file provides basic facilities for string
* to make programming convenient.
*/
#include "common.h"
#include <stdarg.h>
#include <stdio.h>
using std::string;
namespace dgl {
namespace network {
// In most cases, delim contains only one character. In this case, we
// use CalculateReserveForVector to count the number of elements that should
// be reserved in the result vector, and thus optimize SplitStringUsing.
static int CalculateReserveForVector(const std::string& full, const char* delim) {
int count = 0;
if (delim[0] != '\0' && delim[1] == '\0') {
// Optimize the common case where delim is a single character.
char c = delim[0];
const char* p = full.data();
const char* end = p + full.size();
while (p != end) {
if (*p == c) { // This could be optimized with hasless(v,1) trick.
++p;
} else {
while (++p != end && *p != c) {
// Skip to the next occurrence of the delimiter.
}
++count;
}
}
}
return count;
}
void SplitStringUsing(const std::string& full,
const char* delim,
std::vector<std::string>* result) {
CHECK(delim != NULL);
CHECK(result != NULL);
result->reserve(CalculateReserveForVector(full, delim));
back_insert_iterator< std::vector<std::string> > it(*result);
SplitStringToIteratorUsing(full, delim, &it);
}
void SplitStringToSetUsing(const std::string& full,
const char* delim,
std::set<std::string>* result) {
CHECK(delim != NULL);
CHECK(result != NULL);
simple_insert_iterator<std::set<std::string> > it(result);
SplitStringToIteratorUsing(full, delim, &it);
}
static void StringAppendV(string* dst, const char* format, va_list ap) {
// First try with a small fixed size buffer
char space[1024];
// It's possible for methods that use a va_list to invalidate
// the data in it upon use. The fix is to make a copy
// of the structure before using it and use that copy instead.
va_list backup_ap;
va_copy(backup_ap, ap);
int result = vsnprintf(space, sizeof(space), format, backup_ap);
va_end(backup_ap);
if ((result >= 0) && (result < sizeof(space))) {
// It fit
dst->append(space, result);
return;
}
// Repeatedly increase buffer size until it fits
int length = sizeof(space);
while (true) {
if (result < 0) {
// Older behavior: just try doubling the buffer size
length *= 2;
} else {
// We need exactly "result+1" characters
length = result + 1;
}
char* buf = new char[length];
// Restore the va_list before we use it again
va_copy(backup_ap, ap);
result = vsnprintf(buf, length, format, backup_ap);
va_end(backup_ap);
if ((result >= 0) && (result < length)) {
// It fit
dst->append(buf, result);
delete[] buf;
return;
}
delete[] buf;
}
}
string StringPrintf(const char* format, ...) {
va_list ap;
va_start(ap, format);
string result;
StringAppendV(&result, format, ap);
va_end(ap);
return result;
}
void SStringPrintf(string* dst, const char* format, ...) {
va_list ap;
va_start(ap, format);
dst->clear();
StringAppendV(dst, format, ap);
va_end(ap);
}
void StringAppendF(string* dst, const char* format, ...) {
va_list ap;
va_start(ap, format);
StringAppendV(dst, format, ap);
va_end(ap);
}
} // namespace network
} // namespace dgl
/*!
* Copyright (c) 2019 by Contributors
* \file common.h
* \brief This file provides basic facilities for string
* to make programming convenient.
*/
#ifndef DGL_GRAPH_NETWORK_COMMON_H_
#define DGL_GRAPH_NETWORK_COMMON_H_
#include <dmlc/logging.h>
#include <set>
#include <string>
#include <vector>
namespace dgl {
namespace network {
//------------------------------------------------------------------------------
// Subdivide string |full| into substrings according to delimiters
// given in |delim|. |delim| should point to a string including
// one or more characters. Each character is considered a possible
// delimiter. For example:
//
// vector<string> substrings;
// SplitStringUsing("apple orange\tbanana", "\t ", &substrings);
//
// results in three substrings:
//
// substrings.size() == 3
// substrings[0] == "apple"
// substrings[1] == "orange"
// substrings[2] == "banana"
//------------------------------------------------------------------------------
void SplitStringUsing(const std::string& full,
const char* delim,
std::vector<std::string>* result);
// This function has the same semantics as SplitStringUsing. Results
// are saved in an STL set container.
void SplitStringToSetUsing(const std::string& full,
const char* delim,
std::set<std::string>* result);
template <typename T>
struct simple_insert_iterator {
explicit simple_insert_iterator(T* t) : t_(t) { }
simple_insert_iterator<T>& operator=(const typename T::value_type& value) {
t_->insert(value);
return *this;
}
simple_insert_iterator<T>& operator*() { return *this; }
simple_insert_iterator<T>& operator++() { return *this; }
simple_insert_iterator<T>& operator++(int placeholder) { return *this; }
T* t_;
};
template <typename T>
struct back_insert_iterator {
explicit back_insert_iterator(T& t) : t_(t) {}
back_insert_iterator<T>& operator=(const typename T::value_type& value) {
t_.push_back(value);
return *this;
}
back_insert_iterator<T>& operator*() { return *this; }
back_insert_iterator<T>& operator++() { return *this; }
back_insert_iterator<T> operator++(int placeholder) { return *this; }
T& t_;
};
template <typename StringType, typename ITR>
static inline
void SplitStringToIteratorUsing(const StringType& full,
const char* delim,
ITR* result) {
CHECK_NOTNULL(delim);
// Optimize the common case where delim is a single character.
if (delim[0] != '\0' && delim[1] == '\0') {
char c = delim[0];
const char* p = full.data();
const char* end = p + full.size();
while (p != end) {
if (*p == c) {
++p;
} else {
const char* start = p;
while (++p != end && *p != c) {
// Skip to the next occurrence of the delimiter.
}
*(*result)++ = StringType(start, p - start);
}
}
return;
}
std::string::size_type begin_index, end_index;
begin_index = full.find_first_not_of(delim);
while (begin_index != std::string::npos) {
end_index = full.find_first_of(delim, begin_index);
if (end_index == std::string::npos) {
*(*result)++ = full.substr(begin_index);
return;
}
*(*result)++ = full.substr(begin_index, (end_index - begin_index));
begin_index = full.find_first_not_of(delim, end_index);
}
}
//------------------------------------------------------------------------------
// StringPrintf:
//
// For example:
//
// std::string str = StringPrintf("%d", 1); /* str = "1" */
// SStringPrintf(&str, "%d", 2); /* str = "2" */
// StringAppendF(&str, "%d", 3); /* str = "23" */
//------------------------------------------------------------------------------
std::string StringPrintf(const char* format, ...);
void SStringPrintf(std::string* dst, const char* format, ...);
void StringAppendF(std::string* dst, const char* format, ...);
} // namespace network
} // namespace dgl
#endif // DGL_GRAPH_NETWORK_COMMON_H_
......@@ -6,112 +6,165 @@
#ifndef DGL_GRAPH_NETWORK_COMMUNICATOR_H_
#define DGL_GRAPH_NETWORK_COMMUNICATOR_H_
#include <dmlc/logging.h>
#include <string>
#include "msg_queue.h"
namespace dgl {
namespace network {
/*!
* \brief Network Sender for DGL distributed training.
*
* Sender is an abstract class that defines a set of APIs for sending
* binary data over the network. It can be implemented by different underlying
* networking libraries such as TCP socket and ZMQ. One Sender can connect to
* multiple receivers, and it can send data to a specified receiver via the receiver's ID.
* Sender is an abstract class that defines a set of APIs for sending binary
* data messages over the network. It can be implemented by different underlying
* networking libraries such as TCP socket and MPI. One Sender can connect to
* multiple receivers and it can send data to a specified receiver via the receiver's ID.
*/
class Sender {
public:
/*!
* \brief Sender constructor
* \param queue_size size (bytes) of message queue.
* Note that the queue_size parameter is optional.
*/
explicit Sender(int64_t queue_size = 0) {
CHECK_GE(queue_size, 0);
queue_size_ = queue_size;
}
virtual ~Sender() {}
/*!
* \brief Add receiver address and its ID to the namebook
* \param ip receiver's IP address
* \param port receiver's port
* \brief Add receiver's address and ID to the sender's namebook
* \param addr Networking address, e.g., 'socket://127.0.0.1:50091', 'mpi://0'
* \param id receiver's ID
*
* AddReceiver() is not thread-safe and only one thread can invoke this API.
*/
virtual void AddReceiver(const char* ip, int port, int id) = 0;
virtual void AddReceiver(const char* addr, int id) = 0;
/*!
* \brief Connect with all the Receivers
* \return True for success and False for fail
*
* Connect() is not thread-safe and only one thread can invoke this API.
*/
virtual bool Connect() = 0;
/*!
* \brief Send data to specified Receiver
* \param data data buffer for sending
* \param size data size for sending
* \brief Send data to specified Receiver.
* \param msg data message
* \param recv_id receiver's ID
* \return bytes we sent
* > 0 : bytes we sent
* - 1 : error
* \return Status code
*
* (1) The send is non-blocking. There is no guarantee that the message has been
* physically sent out when the function returns.
* (2) The communicator will assume responsibility for the given message.
* (3) The API is multi-thread safe.
* (4) Messages sent to the same receiver are guaranteed to be received in the same order.
* There is no guarantee for messages sent to different receivers.
*/
virtual int64_t Send(const char* data, int64_t size, int recv_id) = 0;
virtual STATUS Send(Message msg, int recv_id) = 0;
/*!
* \brief Finalize Sender
*
* Finalize() is not thread-safe and only one thread can invoke this API.
*/
virtual void Finalize() = 0;
/*!
* \brief Get data buffer
* \return buffer pointer
* \brief Communicator type: 'socket', 'mpi', etc.
*/
virtual char* GetBuffer() = 0;
virtual std::string Type() const = 0;
protected:
/*!
* \brief Set data buffer
* \brief Size of message queue
*/
virtual void SetBuffer(char* buffer) = 0;
int64_t queue_size_;
};
/*!
* \brief Network Receiver for DGL distributed training.
*
* Receiver is an abstract class that defines a set of APIs for receiving binary
* data over the network. It can be implemented by different underlying networking libraries
* such as TCP socket and ZMQ. One Receiver can connect with multiple Senders, and it can receive
* data from these Senders concurrently via multi-threading and message queue.
* Receiver is an abstract class that defines a set of APIs for receiving binary data
* message over network. It can be implemented by different underlying networking
* libraries such as TCP socket and MPI. One Receiver can connect with multiple Senders
* and it can receive data from multiple Senders concurrently.
*/
class Receiver {
public:
/*!
* \brief Receiver constructor
* \param queue_size size of message queue.
* Note that the queue_size parameter is optional.
*/
explicit Receiver(int64_t queue_size = 0) {
if (queue_size < 0) {
LOG(FATAL) << "queue_size cannot be a negative number.";
}
queue_size_ = queue_size;
}
virtual ~Receiver() {}
/*!
* \brief Wait for all of the Senders to connect
* \param ip Receiver's IP address
* \param port Receiver's port
* \brief Wait for all the Senders to connect
* \param addr Networking address, e.g., 'socket://127.0.0.1:50051', 'mpi://0'
* \param num_sender total number of Senders
* \param queue_size size of message queue
* \return True for success and False for fail
*
* Wait() is not thread-safe and only one thread can invoke this API.
*/
virtual bool Wait(const char* addr, int num_sender) = 0;
/*!
* \brief Recv data from Sender
* \param msg pointer of data message
* \param send_id the ID of the Sender that the current message comes from
* \return Status code
*
* (1) The Recv() API is blocking, and will not
* return until it gets data from the message queue.
* (2) The Recv() API is thread-safe.
* (3) The memory is allocated by the communicator, which no longer owns it after the function returns.
*/
virtual bool Wait(const char* ip, int port, int num_sender, int queue_size) = 0;
virtual STATUS Recv(Message* msg, int* send_id) = 0;
/*!
* \brief Recv data from Sender (copy data from message queue)
* \param dest data buffer of destination
* \param max_size maximum size of data buffer
* \return bytes we received
* > 0 : bytes we received
* - 1 : error
* \brief Recv data from a specified Sender
* \param msg pointer of data message
* \param send_id sender's ID
* \return Status code
*
* (1) The RecvFrom() API is blocking, and will not
* return until it gets data from the message queue.
* (2) The RecvFrom() API is thread-safe.
* (3) The memory is allocated by the communicator, which no longer owns it after the function returns.
*/
virtual int64_t Recv(char* dest, int64_t max_size) = 0;
virtual STATUS RecvFrom(Message* msg, int send_id) = 0;
/*!
* \brief Finalize Receiver
*
* Finalize() is not thread-safe and only one thread can invoke this API.
*/
virtual void Finalize() = 0;
/*!
* \brief Get data buffer
* \return buffer pointer
* \brief Communicator type: 'socket', 'mpi', etc.
*/
virtual char* GetBuffer() = 0;
virtual std::string Type() const = 0;
protected:
/*!
* \brief Set data buffer
* \brief Size of message queue
*/
virtual void SetBuffer(char* buffer) = 0;
int64_t queue_size_;
};
} // namespace network
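For orientation, a hedged end-to-end sketch of the refactored interfaces, assuming the SocketSender/SocketReceiver implementations and the kQueueSize constant from network.h; error handling and threading are elided:
#include <cstring>
#include <dmlc/logging.h>
#include "./network.h"                      // kQueueSize
#include "./network/msg_queue.h"            // Message, DefaultMessageDeleter
#include "./network/socket_communicator.h"  // SocketSender / SocketReceiver
using namespace dgl::network;
void SenderSide() {
  SocketSender sender(kQueueSize);
  sender.AddReceiver("socket://127.0.0.1:50051", /*id=*/0);
  CHECK(sender.Connect());
  char* payload = new char[6];
  std::memcpy(payload, "hello", 6);
  Message msg = {payload, 6};
  msg.deallocator = DefaultMessageDeleter;  // communicator frees the buffer after sending
  CHECK_NE(sender.Send(msg, /*recv_id=*/0), -1);
  sender.Finalize();
}
void ReceiverSide() {
  SocketReceiver receiver(kQueueSize);
  CHECK(receiver.Wait("socket://127.0.0.1:50051", /*num_sender=*/1));
  Message msg;
  int send_id = 0;
  CHECK_NE(receiver.Recv(&msg, &send_id), -1);  // blocks until a message arrives
  // ... consume msg.data / msg.size ...
  msg.deallocator(&msg);                        // caller releases the received buffer
  receiver.Finalize();
}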
......
......@@ -14,162 +14,75 @@ namespace network {
using std::string;
MessageQueue::MessageQueue(int64_t queue_size, int num_producers) {
CHECK_LT(0, queue_size);
try {
queue_ = new char[queue_size];
} catch(const std::bad_alloc&) {
LOG(FATAL) << "Not enough memory for message queue.";
}
memset(queue_, '\0', queue_size);
CHECK_GE(queue_size, 0);
CHECK_GE(num_producers, 0);
queue_size_ = queue_size;
free_size_ = queue_size;
write_pointer_ = 0;
num_producers_ = num_producers;
}
MessageQueue::~MessageQueue() {
std::lock_guard<std::mutex> lock(mutex_);
if (nullptr != queue_) {
delete [] queue_;
queue_ = nullptr;
}
}
int64_t MessageQueue::Add(const char* src, int64_t size, bool is_blocking) {
STATUS MessageQueue::Add(Message msg, bool is_blocking) {
// check if message is too long to fit into the queue
if (size > queue_size_) {
LOG(ERROR) << "Message is larger than the queue.";
return -1;
if (msg.size > queue_size_) {
LOG(WARNING) << "Message is larger than the queue.";
return MSG_GT_SIZE;
}
if (size <= 0) {
LOG(ERROR) << "Message size (" << size << ") is negative or zero.";
return -1;
if (msg.size <= 0) {
LOG(WARNING) << "Message size (" << msg.size << ") is negative or zero.";
return MSG_LE_ZERO;
}
std::unique_lock<std::mutex> lock(mutex_);
if (finished_producers_.size() >= num_producers_) {
LOG(ERROR) << "Can't add to buffer when flag_no_more is set.";
return -1;
LOG(WARNING) << "Message queue is closed.";
return QUEUE_CLOSE;
}
if (size > free_size_ && !is_blocking) {
LOG(WARNING) << "Queue is full and message lost.";
return 0;
if (msg.size > free_size_ && !is_blocking) {
return QUEUE_FULL;
}
cond_not_full_.wait(lock, [&]() {
return size <= free_size_;
return msg.size <= free_size_;
});
// Write data into buffer:
// If there is enough space at the tail of the buffer, just append the data;
// else, write up to the end of the buffer and wrap around to its head
message_positions_.push(std::make_pair(write_pointer_, size));
free_size_ -= size;
if (write_pointer_ + size <= queue_size_) {
memcpy(&queue_[write_pointer_], src, size);
write_pointer_ += size;
if (write_pointer_ == queue_size_) {
write_pointer_ = 0;
}
} else {
int64_t size_partial = queue_size_ - write_pointer_;
memcpy(&queue_[write_pointer_], src, size_partial);
memcpy(queue_, &src[size_partial], size - size_partial);
write_pointer_ = size - size_partial;
}
// Add data pointer to queue
queue_.push(msg);
free_size_ -= msg.size;
// not empty signal
cond_not_empty_.notify_one();
return size;
}
int64_t MessageQueue::Add(const string &src, bool is_blocking) {
return Add(src.data(), src.size(), is_blocking);
}
int64_t MessageQueue::Remove(char *dest, int64_t max_size, bool is_blocking) {
int64_t retval;
std::unique_lock<std::mutex> lock(mutex_);
if (message_positions_.empty()) {
if (!is_blocking) {
return 0;
}
if (finished_producers_.size() >= num_producers_) {
return 0;
}
}
cond_not_empty_.wait(lock, [this] {
return !message_positions_.empty() || exit_flag_.load();
});
if (finished_producers_.size() >= num_producers_) {
return 0;
}
MessagePosition & pos = message_positions_.front();
// check if message is too long
if (pos.second > max_size) {
LOG(ERROR) << "Message size exceeds limit, information lost.";
retval = -1;
} else {
// read from buffer:
// if this message is stored in consecutive memory, just read it;
// else, read from the buffer tail and then wrap to the head
if (pos.first + pos.second <= queue_size_) {
memcpy(dest, &queue_[pos.first], pos.second);
} else {
int64_t size_partial = queue_size_ - pos.first;
memcpy(dest, &queue_[pos.first], size_partial);
memcpy(&dest[size_partial], queue_, pos.second - size_partial);
}
retval = pos.second;
}
free_size_ += pos.second;
message_positions_.pop();
cond_not_full_.notify_one();
return retval;
return ADD_SUCCESS;
}
int64_t MessageQueue::Remove(string *dest, bool is_blocking) {
int64_t retval;
STATUS MessageQueue::Remove(Message* msg, bool is_blocking) {
std::unique_lock<std::mutex> lock(mutex_);
if (message_positions_.empty()) {
if (queue_.empty()) {
if (!is_blocking) {
return 0;
return QUEUE_EMPTY;
}
if (finished_producers_.size() >= num_producers_) {
return 0;
LOG(WARNING) << "Message queue is closed.";
return QUEUE_CLOSE;
}
}
cond_not_empty_.wait(lock, [this] {
return !message_positions_.empty() || exit_flag_.load();
return !queue_.empty() || exit_flag_.load();
});
MessagePosition & pos = message_positions_.front();
// read from buffer:
// if this message is stored in consecutive memory, just read it;
// else, read from the buffer tail and then wrap to the head
if (pos.first + pos.second <= queue_size_) {
dest->assign(&queue_[pos.first], pos.second);
} else {
int64_t size_partial = queue_size_ - pos.first;
dest->assign(&queue_[pos.first], size_partial);
dest->append(queue_, pos.second - size_partial);
if (finished_producers_.size() >= num_producers_ && queue_.empty()) {
LOG(WARNING) << "Message queue is closed.";
return QUEUE_CLOSE;
}
retval = pos.second;
free_size_ += pos.second;
message_positions_.pop();
Message & old_msg = queue_.front();
queue_.pop();
msg->data = old_msg.data;
msg->size = old_msg.size;
msg->deallocator = old_msg.deallocator;
free_size_ += old_msg.size;
cond_not_full_.notify_one();
return retval;
return REMOVE_SUCCESS;
}
void MessageQueue::Signal(int producer_id) {
void MessageQueue::SignalFinished(int producer_id) {
std::lock_guard<std::mutex> lock(mutex_);
finished_producers_.insert(producer_id);
// if all producers have finished, consumers should be
......@@ -180,9 +93,14 @@ void MessageQueue::Signal(int producer_id) {
}
}
bool MessageQueue::Empty() const {
std::lock_guard<std::mutex> lock(mutex_);
return queue_.size() == 0;
}
bool MessageQueue::EmptyAndNoMoreAdd() const {
std::lock_guard<std::mutex> lock(mutex_);
return message_positions_.size() == 0 &&
return queue_.size() == 0 &&
finished_producers_.size() >= num_producers_;
}
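A minimal producer/consumer sketch of the new zero-copy queue API (only the Message struct is copied into the queue, never the payload bytes); the status-code names are the ones defined in msg_queue.h:
#include <cstring>
#include <thread>
#include <dmlc/logging.h>
#include "./network/msg_queue.h"
using namespace dgl::network;
int main() {
  MessageQueue queue(/*queue_size=*/1024, /*num_producers=*/1);
  std::thread producer([&queue]() {
    char* data = new char[5];
    std::memcpy(data, "ping", 5);
    Message msg = {data, 5};
    msg.deallocator = DefaultMessageDeleter;
    CHECK_EQ(queue.Add(msg), ADD_SUCCESS);    // zero-copy: only the pointer is queued
    queue.SignalFinished(/*producer_id=*/0);  // lets a blocked consumer observe QUEUE_CLOSE
  });
  Message out;
  if (queue.Remove(&out) == REMOVE_SUCCESS) {
    // ... consume out.data / out.size ...
    out.deallocator(&out);                    // consumer releases the payload
  }
  producer.join();
  return 0;
}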
......
......@@ -13,30 +13,82 @@
#include <mutex>
#include <condition_variable>
#include <atomic>
#include <functional>
namespace dgl {
namespace network {
typedef int STATUS;
/*!
* \brief Status code of message queue
*/
#define ADD_SUCCESS 3400 // Add message successfully
#define MSG_GT_SIZE 3401 // Message size beyond queue size
#define MSG_LE_ZERO 3402 // Message size is not a positive number
#define QUEUE_CLOSE 3403 // Cannot add message when queue is closed
#define QUEUE_FULL 3404 // Cannot add message when queue is full
#define REMOVE_SUCCESS 3405 // Remove message successfully
#define QUEUE_EMPTY 3406 // Cannot remove when queue is empty
/*!
* \brief Message used by network communicator and message queue.
*/
struct Message {
/*!
* \brief Constructor
*/
Message() { }
/*!
* \brief Constructor
*/
Message(char* data_ptr, int64_t data_size)
: data(data_ptr), size(data_size) { }
/*!
* \brief message data
*/
char* data;
/*!
* \brief message size in bytes
*/
int64_t size;
/*!
* \brief pointer to auxiliary data attached to this message (used by custom deallocators)
*/
void* aux_handler;
/*!
* \brief user-defined deallocator, which can be nullptr
*/
std::function<void(Message*)> deallocator = nullptr;
};
/*!
* \brief Message Queue for DGL distributed training.
* \brief Free memory buffer of message
*/
inline void DefaultMessageDeleter(Message* msg) { delete [] msg->data; }
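The deallocator field encodes the ownership hand-off: whoever constructs a Message promises that calling deallocator(&msg) releases everything the message refers to. DefaultMessageDeleter frees a raw new[] buffer, while NDArrayDeleter in network.h instead deletes the NDArray handle kept in aux_handler, so the queue and the communicator can pass pointers around without ever copying payload bytes.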
/*!
* \brief Message Queue for network communication.
*
* MessageQueue is a FIFO queue that adopts the producer/consumer model for data messages.
* It supports one or more producer threads and one or more consumer threads.
* Producers invoke Add() to push data messages into the queue, and consumers
* invoke Remove() to pop data messages from the queue. Add() and Remove() use two condition
* variables to synchronize producer threads and consumer threads. Each producer
* invokes SignalFinished(producer_id) to claim that it is about to finish, where
* producer_id is an integer that uniquely identifies a producer thread. This signaling
* mechanism prevents consumers from waiting after all producers have finished their jobs.
*
* MessageQueue is a circular queue that uses a ring-buffer in a
* producer/consumer model. It supports one or more producer
* threads and one or more consumer threads. Producers invoke Add()
* to push data elements into the queue, and consumers invoke
* Remove() to pop data elements. Add() and Remove() use two condition
* variables to synchronize producers and consumers. Each producer invokes
* Signal(producer_id) to claim that it is about to finish, where
* producer_id is an integer that uniquely identifies a producer thread. This
* signaling mechanism prevents consumers from waiting after all producers
* have finished their jobs.
* MessageQueue is thread-safe.
*
*/
class MessageQueue {
public:
/*!
* \brief MessageQueue constructor
* \param queue_size size of message queue
* \param queue_size size (bytes) of message queue
* \param num_producers number of producers, use 1 by default
*/
MessageQueue(int64_t queue_size /* in bytes */,
......@@ -45,59 +97,34 @@ class MessageQueue {
/*!
* \brief MessageQueue deconstructor
*/
~MessageQueue();
~MessageQueue() {}
/*!
* \brief Add data to the message queue
* \param src The data pointer
* \param size The size of data
* \param is_blocking Block function if cannot add, else return
* \return bytes added to the queue
* > 0 : size of message
* = 0 : not enough space for this message (when is_blocking = false)
* - 1 : error
* \brief Add message to the queue
* \param msg data message
* \param is_blocking block if there is not enough space in the queue, else return immediately
* \return Status code
*/
int64_t Add(const char* src, int64_t size, bool is_blocking = true);
/*!
* \brief Add data to the message queue
* \param src The data string
* \param is_blocking Block function if cannot add, else return
* \return bytes added to queue
* > 0 : size of message
* = 0 : not enough space for this message (when is_blocking = false)
* - 1 : error
*/
int64_t Add(const std::string& src, bool is_blocking = true);
STATUS Add(Message msg, bool is_blocking = true);
/*!
* \brief Remove message from the queue
* \param dest The destination data pointer
* \param max_size Maximal size of data
* \param is_blocking Block function if cannot remove, else return
* \return bytes removed from queue
* > 0 : size of message
* = 0 : queue is empty
* - 1 : error
* \param msg pointer of data msg
* \param is_blocking Blocking if cannot remove, else return
* \return Status code
*/
int64_t Remove(char *dest, int64_t max_size, bool is_blocking = true);
STATUS Remove(Message* msg, bool is_blocking = true);
/*!
* \brief Remove message from the queue
* \param dest The destination data string
* \param is_blocking Block function if cannot remove, else return
* \return bytes removed from queue
* > 0 : size of message
* = 0 : queue is empty
* - 1 : error
* \brief Signal that producer producer_id will no longer produce anything
* \param producer_id An integer uniquely to identify a producer thread
*/
int64_t Remove(std::string *dest, bool is_blocking = true);
void SignalFinished(int producer_id);
/*!
* \brief Signal that producer producer_id will no longer produce anything
* \param producer_id An integer that uniquely identifies a producer thread
* \return true if queue is empty.
*/
void Signal(int producer_id);
bool Empty() const;
/*!
* \return true if queue is empty and all num_producers have signaled.
......@@ -105,13 +132,10 @@ class MessageQueue {
bool EmptyAndNoMoreAdd() const;
protected:
typedef std::pair<int64_t /* message_start_position in queue_ */,
int64_t /* message_length */> MessagePosition;
/*!
* \brief Pointer to the queue
* \brief message queue
*/
char* queue_;
std::queue<Message> queue_;
/*!
* \brief Size of the queue in bytes
......@@ -123,24 +147,11 @@ class MessageQueue {
*/
int64_t free_size_;
/*!
* \brief Location in queue_ for where to write the next element
* Note that we do not need a read pointer since all messages are indexed
* by message_positions_, and the first element in message_positions_
* denotes where we should read
*/
int64_t write_pointer_;
/*!
* \brief Number of producers; used to check whether all producers have finished producing
*/
size_t num_producers_;
/*!
* \brief Messages in the queue
*/
std::queue<MessagePosition> message_positions_;
/*!
* \brief Store finished producer id
*/
......
/*!
* Copyright (c) 2019 by Contributors
* \file serialize.cc
* \brief Serialization for DGL distributed training.
*/
#include "serialize.h"
#include <dmlc/logging.h>
#include <dgl/immutable_graph.h>
#include <cstring>
#include "../network.h"
namespace dgl {
namespace network {
const int kNumTensor = 7;  // We need to serialize 7 components (tensors) here
int64_t SerializeSampledSubgraph(char* data,
const CSRPtr csr,
const IdArray& node_mapping,
const IdArray& edge_mapping,
const IdArray& layer_offsets,
const IdArray& flow_offsets) {
int64_t total_size = 0;
// For each component, we first write its size at the
// beginning of the buffer and then write its binary data
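// The resulting buffer layout is a sequence of length-prefixed fields,
// in this exact order:
//   [int64_t size][node_mapping] [int64_t size][layer_offsets]
//   [int64_t size][flow_offsets] [int64_t size][edge_mapping]
//   [int64_t size][indices] [int64_t size][edge_ids] [int64_t size][indptr]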
int64_t node_mapping_size = node_mapping->shape[0] * sizeof(dgl_id_t);
int64_t edge_mapping_size = edge_mapping->shape[0] * sizeof(dgl_id_t);
int64_t layer_offsets_size = layer_offsets->shape[0] * sizeof(dgl_id_t);
int64_t flow_offsets_size = flow_offsets->shape[0] * sizeof(dgl_id_t);
int64_t indptr_size = csr->indptr().GetSize();
int64_t indices_size = csr->indices().GetSize();
int64_t edge_ids_size = csr->edge_ids().GetSize();
total_size += node_mapping_size;
total_size += edge_mapping_size;
total_size += layer_offsets_size;
total_size += flow_offsets_size;
total_size += indptr_size;
total_size += indices_size;
total_size += edge_ids_size;
total_size += kNumTensor * sizeof(int64_t);
if (total_size > kMaxBufferSize) {
LOG(FATAL) << "Message size: (" << total_size
<< ") is larger than buffer size: ("
<< kMaxBufferSize << ")";
}
// Write binary data to buffer
char* data_ptr = data;
dgl_id_t* node_map_data = static_cast<dgl_id_t*>(node_mapping->data);
dgl_id_t* edge_map_data = static_cast<dgl_id_t*>(edge_mapping->data);
dgl_id_t* layer_off_data = static_cast<dgl_id_t*>(layer_offsets->data);
dgl_id_t* flow_off_data = static_cast<dgl_id_t*>(flow_offsets->data);
dgl_id_t* indptr = static_cast<dgl_id_t*>(csr->indptr()->data);
dgl_id_t* indices = static_cast<dgl_id_t*>(csr->indices()->data);
dgl_id_t* edge_ids = static_cast<dgl_id_t*>(csr->edge_ids()->data);
// node_mapping
*(reinterpret_cast<int64_t*>(data_ptr)) = node_mapping_size;
data_ptr += sizeof(int64_t);
memcpy(data_ptr, node_map_data, node_mapping_size);
data_ptr += node_mapping_size;
// layer_offsets
*(reinterpret_cast<int64_t*>(data_ptr)) = layer_offsets_size;
data_ptr += sizeof(int64_t);
memcpy(data_ptr, layer_off_data, layer_offsets_size);
data_ptr += layer_offsets_size;
// flow_offsets
*(reinterpret_cast<int64_t*>(data_ptr)) = flow_offsets_size;
data_ptr += sizeof(int64_t);
memcpy(data_ptr, flow_off_data, flow_offsets_size);
data_ptr += flow_offsets_size;
// edge_mapping
*(reinterpret_cast<int64_t*>(data_ptr)) = edge_mapping_size;
data_ptr += sizeof(int64_t);
memcpy(data_ptr, edge_map_data, edge_mapping_size);
data_ptr += edge_mapping_size;
// indices (CSR)
*(reinterpret_cast<int64_t*>(data_ptr)) = indices_size;
data_ptr += sizeof(int64_t);
memcpy(data_ptr, indices, indices_size);
data_ptr += indices_size;
// edge_ids (CSR)
*(reinterpret_cast<int64_t*>(data_ptr)) = edge_ids_size;
data_ptr += sizeof(int64_t);
memcpy(data_ptr, edge_ids, edge_ids_size);
data_ptr += edge_ids_size;
// indptr (CSR)
*(reinterpret_cast<int64_t*>(data_ptr)) = indptr_size;
data_ptr += sizeof(int64_t);
memcpy(data_ptr, indptr, indptr_size);
data_ptr += indptr_size;
return total_size;
}
void DeserializeSampledSubgraph(char* data,
CSRPtr* csr,
IdArray* node_mapping,
IdArray* edge_mapping,
IdArray* layer_offsets,
IdArray* flow_offsets) {
// For each component, we first read its size at the
// beginning of the buffer and then read its binary data
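// The read order must mirror the write order in SerializeSampledSubgraph:
// node_mapping, layer_offsets, flow_offsets, edge_mapping, indices,
// edge_ids, indptr.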
char* data_ptr = data;
// node_mapping
int64_t tensor_size = *(reinterpret_cast<int64_t*>(data_ptr));
int64_t num_vertices = tensor_size / sizeof(int64_t);
data_ptr += sizeof(int64_t);
*node_mapping = IdArray::Empty({static_cast<int64_t>(num_vertices)},
DLDataType{kDLInt, 64, 1}, DLContext{kDLCPU, 0});
dgl_id_t* node_map_data = static_cast<dgl_id_t*>((*node_mapping)->data);
memcpy(node_map_data, data_ptr, tensor_size);
data_ptr += tensor_size;
// layer offsets
tensor_size = *(reinterpret_cast<int64_t*>(data_ptr));
int64_t num_hops_add_one = tensor_size / sizeof(int64_t);
data_ptr += sizeof(int64_t);
*layer_offsets = IdArray::Empty({static_cast<int64_t>(num_hops_add_one)},
DLDataType{kDLInt, 64, 1}, DLContext{kDLCPU, 0});
dgl_id_t* layer_off_data = static_cast<dgl_id_t*>((*layer_offsets)->data);
memcpy(layer_off_data, data_ptr, tensor_size);
data_ptr += tensor_size;
// flow offsets
tensor_size = *(reinterpret_cast<int64_t*>(data_ptr));
int64_t num_hops = tensor_size / sizeof(int64_t);
data_ptr += sizeof(int64_t);
*flow_offsets = IdArray::Empty({static_cast<int64_t>(num_hops)},
DLDataType{kDLInt, 64, 1}, DLContext{kDLCPU, 0});
dgl_id_t* flow_off_data = static_cast<dgl_id_t*>((*flow_offsets)->data);
memcpy(flow_off_data, data_ptr, tensor_size);
data_ptr += tensor_size;
// edge_mapping
tensor_size = *(reinterpret_cast<int64_t*>(data_ptr));
int64_t num_edges = tensor_size / sizeof(int64_t);
data_ptr += sizeof(int64_t);
*edge_mapping = IdArray::Empty({static_cast<int64_t>(num_edges)},
DLDataType{kDLInt, 64, 1}, DLContext{kDLCPU, 0});
dgl_id_t* edge_mapping_data = static_cast<dgl_id_t*>((*edge_mapping)->data);
memcpy(edge_mapping_data, data_ptr, tensor_size);
data_ptr += tensor_size;
// Construct sub_csr_graph
// TODO(minjie): multigraph flag
*csr = CSRPtr(new CSR(num_vertices, num_edges, false));
// indices (CSR)
tensor_size = *(reinterpret_cast<int64_t*>(data_ptr));
data_ptr += sizeof(int64_t);
dgl_id_t* col_list_out = static_cast<dgl_id_t*>((*csr)->indices()->data);
memcpy(col_list_out, data_ptr, tensor_size);
data_ptr += tensor_size;
// edge_ids (CSR)
tensor_size = *(reinterpret_cast<int64_t*>(data_ptr));
data_ptr += sizeof(int64_t);
dgl_id_t* edge_ids = static_cast<dgl_id_t*>((*csr)->edge_ids()->data);
memcpy(edge_ids, data_ptr, tensor_size);
data_ptr += tensor_size;
// indptr (CSR)
tensor_size = *(reinterpret_cast<int64_t*>(data_ptr));
data_ptr += sizeof(int64_t);
dgl_id_t* indptr_out = static_cast<dgl_id_t*>((*csr)->indptr()->data);
memcpy(indptr_out, data_ptr, tensor_size);
data_ptr += tensor_size;
}
} // namespace network
} // namespace dgl
/*!
* Copyright (c) 2019 by Contributors
* \file serialize.h
* \brief Serialization for DGL distributed training.
*/
#ifndef DGL_GRAPH_NETWORK_SERIALIZE_H_
#define DGL_GRAPH_NETWORK_SERIALIZE_H_
#include <dgl/sampler.h>
#include <dgl/immutable_graph.h>
namespace dgl {
namespace network {
/*!
* \brief Serialize sampled subgraph to binary data
* \param data pointer of data buffer
* \param csr subgraph csr
* \param node_mapping node mapping in NodeFlowIndex
* \param edge_mapping edge mapping in NodeFlowIndex
* \param layer_offsets layer offsets in NodeFlowIndex
* \param flow_offsets flow offsets in NodeFlowIndex
* \return the total size of the serialized binary data
*/
int64_t SerializeSampledSubgraph(char* data,
const CSRPtr csr,
const IdArray& node_mapping,
const IdArray& edge_mapping,
const IdArray& layer_offsets,
const IdArray& flow_offsets);
/*!
* \brief Deserialize sampled subgraph from binary data
* \param data pointer of data buffer
* \param csr output subgraph csr
* \param node_mapping output node mapping in NodeFlowIndex
* \param edge_mapping output edge mapping in NodeFlowIndex
* \param layer_offsets output layer offsets in NodeFlowIndex
* \param flow_offsets output flow offsets in NodeFlowIndex
*/
void DeserializeSampledSubgraph(char* data,
CSRPtr* csr,
IdArray* node_mapping,
IdArray* edge_mapping,
IdArray* layer_offsets,
IdArray* flow_offsets);
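// A minimal round-trip sketch (illustrative only; the caller owns the
// buffer, which must stay below kMaxBufferSize):
//
//   char* buffer = new char[kMaxBufferSize];
//   SerializeSampledSubgraph(buffer, csr, node_mapping, edge_mapping,
//                            layer_offsets, flow_offsets);
//   CSRPtr csr_out;
//   IdArray node_out, edge_out, layer_out, flow_out;
//   DeserializeSampledSubgraph(buffer, &csr_out, &node_out, &edge_out,
//                              &layer_out, &flow_out);
//   delete [] buffer;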
// TODO(chao): we can add compression and decompression method here
} // namespace network
} // namespace dgl
#endif // DGL_GRAPH_NETWORK_SERIALIZE_H_
......@@ -5,9 +5,12 @@
*/
#include <dmlc/logging.h>
#include <string.h>
#include <stdlib.h>
#include <time.h>
#include "socket_communicator.h"
#include "../../c_api_common.h"
#include "../network.h"
#ifdef _WIN32
#include <windows.h>
......@@ -18,34 +21,55 @@
namespace dgl {
namespace network {
/////////////////////////////////////// SocketSender ///////////////////////////////////////////
void SocketSender::AddReceiver(const char* addr, int recv_id) {
CHECK_NOTNULL(addr);
if (recv_id < 0) {
LOG(FATAL) << "recv_id cannot be a negative number.";
}
std::vector<std::string> substring;
std::vector<std::string> ip_and_port;
SplitStringUsing(addr, "//", &substring);
// Check address format
if (substring[0] != "socket:" || substring.size() != 2) {
LOG(FATAL) << "Incorrect address format:" << addr
<< " Please provide right address format, "
<< "e.g, 'socket://127.0.0.1:50051'. ";
}
// Get IP and port
SplitStringUsing(substring[1], ":", &ip_and_port);
if (ip_and_port.size() != 2) {
LOG(FATAL) << "Incorrect address format:" << addr
<< " Please provide right address format, "
<< "e.g, 'socket://127.0.0.1:50051'. ";
}
IPAddr address;
address.ip = ip_and_port[0];
address.port = std::stoi(ip_and_port[1]);
receiver_addrs_[recv_id] = address;
msg_queue_[recv_id] = std::make_shared<MessageQueue>(queue_size_);
}
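// For example (a sketch with illustrative addresses), a sender that feeds
// two receivers:
//
//   SocketSender sender(queue_size);
//   sender.AddReceiver("socket://127.0.0.1:50091", 0);
//   sender.AddReceiver("socket://127.0.0.1:50092", 1);
//   sender.Connect();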
bool SocketSender::Connect() {
// Create N sockets for Receiver
for (const auto& r : receiver_addrs_) {
int ID = r.first;
sockets_[ID] = std::make_shared<TCPSocket>();
TCPSocket* client_socket = sockets_[ID].get();
bool bo = false;
int try_count = 0;
const char* ip = r.second.ip.c_str();
int port = r.second.port;
while (bo == false && try_count < kMaxTryCount) {
if (client_socket->Connect(ip, port)) {
LOG(INFO) << "Connected to Receiver: " << ip << ":" << port;
bo = true;
} else {
LOG(ERROR) << "Cannot connect to Receiver: " << ip << ":" << port
<< ", try again ...";
bo = false;
try_count++;
#ifdef _WIN32
Sleep(1);
......@@ -57,101 +81,200 @@ bool SocketSender::Connect() {
if (bo == false) {
return bo;
}
// Create a new thread for this socket connection
threads_[ID] = std::make_shared<std::thread>(
SendLoop,
client_socket,
msg_queue_[ID].get());
}
return true;
}
STATUS SocketSender::Send(Message msg, int recv_id) {
CHECK_NOTNULL(msg.data);
CHECK_GT(msg.size, 0);
CHECK_GE(recv_id, 0);
// Add data message to message queue
STATUS code = msg_queue_[recv_id]->Add(msg);
return code;
}
void SocketSender::Finalize() {
// Send a signal to tell the msg_queue to finish its job
for (auto& mq : msg_queue_) {
// wait until queue is empty
while (mq.second->Empty() == false) {
#ifdef _WIN32
// just loop
#else // !_WIN32
usleep(1000);
#endif // _WIN32
}
int ID = mq.first;
mq.second->SignalFinished(ID);
}
// Block main thread until all socket-threads finish their jobs
for (auto& thread : threads_) {
thread.second->join();
}
// Clear all sockets
for (auto& socket : sockets_) {
socket.second->Close();
}
}
void SocketSender::SendLoop(TCPSocket* socket, MessageQueue* queue) {
CHECK_NOTNULL(socket);
CHECK_NOTNULL(queue);
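// Wire format: each message goes out as [int64_t size][payload bytes].
// A zero size is the end-signal that tells the remote RecvLoop to exit.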
bool exit = false;
while (!exit) {
Message msg;
STATUS code = queue->Remove(&msg);
if (code == QUEUE_CLOSE) {
msg.size = 0; // send an end-signal to receiver
exit = true;
}
// First send the size.
// If exit == true, we send a zero size to the receiver.
int64_t sent_bytes = 0;
while (static_cast<size_t>(sent_bytes) < sizeof(int64_t)) {
int64_t max_len = sizeof(int64_t) - sent_bytes;
int64_t tmp = socket->Send(
reinterpret_cast<char*>(&msg.size)+sent_bytes,
max_len);
CHECK_NE(tmp, -1);
sent_bytes += tmp;
}
// Then send the data
sent_bytes = 0;
while (sent_bytes < msg.size) {
int64_t max_len = msg.size - sent_bytes;
int64_t tmp = socket->Send(msg.data+sent_bytes, max_len);
CHECK_NE(tmp, -1);
sent_bytes += tmp;
}
// free the message via its deallocator
if (msg.deallocator != nullptr) {
msg.deallocator(&msg);
}
}
}
/////////////////////////////////////// SocketReceiver ///////////////////////////////////////////
bool SocketReceiver::Wait(const char* addr, int num_sender) {
CHECK_NOTNULL(addr);
CHECK_GT(num_sender, 0);
std::vector<std::string> substring;
std::vector<std::string> ip_and_port;
SplitStringUsing(addr, "//", &substring);
// Check address format
if (substring[0] != "socket:" || substring.size() != 2) {
LOG(FATAL) << "Incorrect address format:" << addr
<< " Please provide right address format, "
<< "e.g, 'socket://127.0.0.1:50051'. ";
}
// Get IP and port
SplitStringUsing(substring[1], ":", &ip_and_port);
if (ip_and_port.size() != 2) {
LOG(FATAL) << "Incorrect address format:" << addr
<< " Please provide right address format, "
<< "e.g, 'socket://127.0.0.1:50051'. ";
}
std::string ip = ip_and_port[0];
int port = std::stoi(ip_and_port[1]);
// Initialize message queue for each connection
num_sender_ = num_sender;
for (int i = 0; i < num_sender_; ++i) {
msg_queue_[i] = std::make_shared<MessageQueue>(queue_size_);
}
// Initialize socket and socket-thread
server_socket_ = new TCPSocket();
server_socket_->SetTimeout(kTimeOut * 60 * 1000);  // in milliseconds
// Bind socket
if (server_socket_->Bind(ip.c_str(), port) == false) {
LOG(FATAL) << "Cannot bind to " << ip << ":" << port;
return false;
}
LOG(INFO) << "Bind to " << ip << ":" << port;
// Listen
if (server_socket_->Listen(kMaxConnection) == false) {
LOG(FATAL) << "Cannot listen on " << ip << ":" << port;
return false;
}
LOG(INFO) << "Listen on " << ip << ":" << port << ", wait sender connect ...";
// Accept all sender sockets
std::string accept_ip;
int accept_port;
for (int i = 0; i < num_sender_; ++i) {
sockets_[i] = std::make_shared<TCPSocket>();
if (server_socket_->Accept(sockets_[i].get(), &accept_ip, &accept_port) == false) {
LOG(WARNING) << "Error on accept socket.";
return false;
}
// create new thread for each socket
threads_[i] = std::make_shared<std::thread>(
RecvLoop,
sockets_[i].get(),
msg_queue_[i].get());
LOG(INFO) << "Accept new sender: " << accept_ip << ":" << accept_port;
}
return true;
}
STATUS SocketReceiver::Recv(Message* msg, int* send_id) {
// loop until get a message
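// Poll each per-sender queue with a non-blocking Remove() in round-robin
// order, returning as soon as any queue yields a message or an error code.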
for (;;) {
for (auto& mq : msg_queue_) {
*send_id = mq.first;
// We use a non-blocking remove here
STATUS code = msg_queue_[*send_id]->Remove(msg, false);
if (code == QUEUE_EMPTY) {
continue; // jump to the next queue
} else {
return code;
}
}
}
}
STATUS SocketReceiver::RecvFrom(Message* msg, int send_id) {
// Get message from specified message queue
STATUS code = msg_queue_[send_id]->Remove(msg);
return code;
}
void SocketReceiver::Finalize() {
// Send a signal to tell the message queue to finish its job
for (auto& mq : msg_queue_) {
// wait until queue is empty
while (mq.second->Empty() == false) {
#ifdef _WIN32
// just loop
#else // !_WIN32
usleep(1000);
#endif // _WIN32
}
int ID = mq.first;
mq.second->SignalFinished(ID);
}
// Block main thread until all socket-threads finish their jobs
for (auto& thread : threads_) {
thread.second->join();
}
// Clear all sockets
for (auto& socket : sockets_) {
socket.second->Close();
}
}
void SocketReceiver::RecvLoop(TCPSocket* socket, MessageQueue* queue) {
CHECK_NOTNULL(socket);
CHECK_NOTNULL(queue);
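// Mirror of SendLoop's wire format: read the int64_t length prefix first,
// then the payload; a zero length is the sender's end-signal.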
for (;;) {
// If the main thread has finished its job
if (queue->EmptyAndNoMoreAdd()) {
return; // exit loop thread
}
// First recv the size
int64_t received_bytes = 0;
int64_t data_size = 0;
......@@ -160,54 +283,36 @@ void SocketReceiver::MsgHandler(TCPSocket* socket, MessageQueue* queue, int id)
int64_t tmp = socket->Receive(
reinterpret_cast<char*>(&data_size)+received_bytes,
max_len);
CHECK_NE(tmp, -1);
received_bytes += tmp;
}
if (data_size < 0) {
LOG(FATAL) << "Recv data error (data_size: " << data_size << ")";
} else if (data_size == 0) {
// This is an end-signal sent by client
return;
} else {
char* buffer = nullptr;
try {
buffer = new char[data_size];
} catch(const std::bad_alloc&) {
LOG(FATAL) << "Cannot allocate enough memory for message, "
<< "(message size: " << data_size << ")";
}
received_bytes = 0;
while (received_bytes < data_size) {
int64_t max_len = data_size - received_bytes;
int64_t tmp = socket->Receive(buffer+received_bytes, max_len);
CHECK_NE(tmp, -1);
received_bytes += tmp;
}
Message msg;
msg.data = buffer;
msg.size = data_size;
msg.deallocator = DefaultMessageDeleter;
queue->Add(msg);
}
}
}
} // namespace network
......
......@@ -14,136 +14,172 @@
#include "communicator.h"
#include "msg_queue.h"
#include "tcp_socket.h"
#include "common.h"
namespace dgl {
namespace network {
using dgl::network::MessageQueue;
using dgl::network::TCPSocket;
using dgl::network::Sender;
using dgl::network::Receiver;
static int kMaxTryCount = 1024;   // maximal try count: 1024
static int kTimeOut = 10; // 10 minutes for socket timeout
static int kMaxConnection = 1024; // maximal connection: 1024
/*!
* \brief Networking address
*/
struct IPAddr {
std::string ip;
int port;
};
/*!
* \brief SocketSender for DGL distributed training.
*
* SocketSender is the communicator implemented by TCP socket.
*/
class SocketSender : public Sender {
public:
/*!
* \brief Sender constructor
* \param queue_size size of message queue
*/
explicit SocketSender(int64_t queue_size) : Sender(queue_size) {}
/*!
* \brief Add receiver's address and ID to the sender's namebook
* \param addr Networking address, e.g., 'socket://127.0.0.1:50091', 'mpi://0'
* \param recv_id receiver's ID
*
* AddReceiver() is not thread-safe and only one thread can invoke this API.
*/
void AddReceiver(const char* addr, int recv_id);
/*!
* \brief Connect with all the Receivers
* \return True for success and False for fail
*
* Connect() is not thread-safe and only one thread can invoke this API.
*/
bool Connect();
/*!
* \brief Send data to specified Receiver. Actually pushing message to message queue.
* \param msg data message
* \param recv_id receiver's ID
* \return Status code
*
* (1) The send is non-blocking. There is no guarantee that the message has been
* physically sent out when the function returns.
* (2) The communicator will assume the responsibility of the given message.
* (3) The API is multi-thread safe.
* (4) Messages sent to the same receiver are guaranteed to be received in the same order.
* There is no guarantee for messages sent to different receivers.
*/
STATUS Send(Message msg, int recv_id);
/*!
* \brief Finalize SocketSender
*
* Finalize() is not thread-safe and only one thread can invoke this API.
*/
void Finalize();
/*!
* \brief Communicator type: 'socket'
*/
inline std::string Type() const { return std::string("socket"); }
private:
/*!
* \brief socket for each connection of receiver
*/
std::unordered_map<int /* receiver ID */, std::shared_ptr<TCPSocket>> sockets_;
/*!
* \brief receivers' address
*/
std::unordered_map<int /* receiver ID */, IPAddr> receiver_addrs_;
/*!
* \brief message queue for each socket connection
*/
std::unordered_map<int /* receiver ID */, std::shared_ptr<MessageQueue>> msg_queue_;
/*!
* \brief Independent thread for each socket connection
*/
std::unordered_map<int /* receiver ID */, std::shared_ptr<std::thread>> threads_;
/*!
* \brief Send-loop that runs in a dedicated per-socket thread
* \param socket TCPSocket for current connection
* \param queue message_queue for current connection
*
* Note that the SendLoop will finish its loop-job and exit the thread
* when the main thread invokes the SignalFinished() API on the message queue.
*/
static void SendLoop(TCPSocket* socket, MessageQueue* queue);
};
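// Typical sender-side flow (a sketch, not part of the API; the address,
// queue size, and payload are illustrative):
//
//   SocketSender sender(500 * 1024);
//   sender.AddReceiver("socket://127.0.0.1:50051", 0);
//   if (sender.Connect()) {
//     Message msg = {payload_ptr, payload_size};  // sender takes ownership
//     msg.deallocator = DefaultMessageDeleter;
//     sender.Send(msg, 0);
//     sender.Finalize();
//   }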
/*!
* \brief SocketReceiver for DGL distributed training.
*
* SocketReceiver is the communicator implemented by TCP socket.
*/
class SocketReceiver : public Receiver {
public:
/*!
* \brief Receiver constructor
* \param queue_size size of message queue.
*/
explicit SocketReceiver(int64_t queue_size) : Receiver(queue_size) {}
/*!
* \brief Wait for all the Senders to connect
* \param addr Networking address, e.g., 'socket://127.0.0.1:50051', 'mpi://0'
* \param num_sender total number of Senders
* \return True for success and False for fail
*
* Wait() is not thread-safe and only one thread can invoke this API.
*/
bool Wait(const char* addr, int num_sender);
/*!
* \brief Recv data from Sender. Actually removing data from msg_queue.
* \param msg pointer of data message
* \param send_id the ID of the sender that the message comes from
* \return Status code
*
* (1) The Recv() API is blocking, which will not
* return until getting data from message queue.
* (2) The Recv() API is thread-safe.
* (3) The memory of the message is allocated by the communicator, but the
* communicator no longer owns it after the function returns.
*/
STATUS Recv(Message* msg, int* send_id);
/*!
* \brief Recv data from a specified Sender. Actually removing data from msg_queue.
* \param msg pointer of data message
* \param send_id sender's ID
* \return Status code
*
* (1) The RecvFrom() API is blocking, which will not
* return until getting data from message queue.
* (2) The RecvFrom() API is thread-safe.
* (3) The memory of the message is allocated by the communicator, but the
* communicator no longer owns it after the function returns.
*/
STATUS RecvFrom(Message* msg, int send_id);
/*!
* \brief Finalize SocketReceiver
*
* Finalize() is not thread-safe and only one thread can invoke this API.
*/
void Finalize();
/*!
* \brief Communicator type: 'socket'
*/
inline std::string Type() const { return std::string("socket"); }
private:
/*!
......@@ -152,37 +188,34 @@ class SocketReceiver : public Receiver {
int num_sender_;
/*!
* \brief maximal size of message queue
*/
int64_t queue_size_;
/*!
* \brief server socket for listening connections
*/
TCPSocket* server_socket_;
/*!
* \brief socket for each client connection
*/
std::unordered_map<int /* Sender (virtual) ID */, std::shared_ptr<TCPSocket>> sockets_;
/*!
* \brief Message queue for each socket connection
*/
std::unordered_map<int /* Sender (virtual) ID */, std::shared_ptr<MessageQueue>> msg_queue_;
/*!
* \brief Independent thread for each socket connection
*/
std::unordered_map<int /* Sender (virtual) ID */, std::shared_ptr<std::thread>> threads_;
/*!
* \brief Recv-loop that runs in a dedicated per-socket thread
* \param socket client socket
* \param queue message queue
*
* Note that the RecvLoop will finish its loop-job and exit the thread
* when the main thread invokes the SignalFinished() API on the message queue.
*/
static void RecvLoop(TCPSocket* socket, MessageQueue* queue);
};
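// Typical receiver-side flow (a sketch; the address is illustrative):
//
//   SocketReceiver receiver(500 * 1024);
//   receiver.Wait("socket://127.0.0.1:50051", /* num_sender */ 1);
//   Message msg;
//   int sender_id;
//   if (receiver.Recv(&msg, &sender_id) == REMOVE_SUCCESS) {
//     // consume msg.data / msg.size, then release the payload
//     if (msg.deallocator != nullptr) msg.deallocator(&msg);
//   }
//   receiver.Finalize();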
} // namespace network
......
/*!
* Copyright (c) 2019 by Contributors
* \file graph_index_test.cc
* \brief Test GraphIndex
*/
#include <gtest/gtest.h>
#include <dgl/graph.h>
......
......@@ -5,39 +5,100 @@
*/
#include <gtest/gtest.h>
#include <string>
#include <thread>
#include <vector>
#include "../src/graph/network/msg_queue.h"
using std::string;
using dgl::network::Message;
using dgl::network::MessageQueue;
TEST(MessageQueueTest, AddRemove) {
MessageQueue queue(5, 1); // size:5, num_of_producer:1
// msg 1
std::string str_1("111");
Message msg_1 = {const_cast<char*>(str_1.data()), 3};
EXPECT_EQ(queue.Add(msg_1), ADD_SUCCESS);
// msg 2
std::string str_2("22");
Message msg_2 = {const_cast<char*>(str_2.data()), 2};
EXPECT_EQ(queue.Add(msg_2), ADD_SUCCESS);
// msg 3
std::string str_3("xxxx");
Message msg_3 = {const_cast<char*>(str_3.data()), 4};
EXPECT_EQ(queue.Add(msg_3, false), QUEUE_FULL);
// msg 4
Message msg_4;
EXPECT_EQ(queue.Remove(&msg_4), REMOVE_SUCCESS);
EXPECT_EQ(string(msg_4.data, msg_4.size), string("111"));
// msg 5
Message msg_5;
EXPECT_EQ(queue.Remove(&msg_5), REMOVE_SUCCESS);
EXPECT_EQ(string(msg_5.data, msg_5.size), string("22"));
// msg 6
std::string str_6("33333");
Message msg_6 = {const_cast<char*>(str_6.data()), 5};
EXPECT_EQ(queue.Add(msg_6), ADD_SUCCESS);
// msg 7
Message msg_7;
EXPECT_EQ(queue.Remove(&msg_7), REMOVE_SUCCESS);
EXPECT_EQ(string(msg_7.data, msg_7.size), string("33333"));
// msg 8
Message msg_8;
EXPECT_EQ(queue.Remove(&msg_8, false), QUEUE_EMPTY); // non-blocking remove
// msg 9
std::string str_9("666666");
Message msg_9 = {const_cast<char*>(str_9.data()), 6};
EXPECT_EQ(queue.Add(msg_9), MSG_GT_SIZE); // exceed queue size
// msg 10
std::string str_10("55555");
Message msg_10 = {const_cast<char*>(str_10.data()), 5};
EXPECT_EQ(queue.Add(msg_10), ADD_SUCCESS);
// msg 11
Message msg_11;
EXPECT_EQ(queue.Remove(&msg_11), REMOVE_SUCCESS);
}
TEST(MessageQueueTest, EmptyAndNoMoreAdd) {
MessageQueue queue(5, 2); // size:5, num_of_producer:2
EXPECT_EQ(queue.EmptyAndNoMoreAdd(), false);
EXPECT_EQ(queue.Empty(), true);
queue.SignalFinished(1);
queue.SignalFinished(1);
EXPECT_EQ(queue.EmptyAndNoMoreAdd(), false);
queue.SignalFinished(2);
EXPECT_EQ(queue.EmptyAndNoMoreAdd(), true);
}
const int kNumOfProducer = 100;
const int kNumOfMessage = 100;
std::string str_apple("apple");
void start_add(MessageQueue* queue, int id) {
for (int i = 0; i < kNumOfMessage; ++i) {
Message msg = {const_cast<char*>(str_apple.data()), 5};
EXPECT_EQ(queue->Add(msg), ADD_SUCCESS);
}
queue->SignalFinished(id);
}
TEST(MessageQueueTest, MultiThread) {
MessageQueue queue(100000, kNumOfProducer);
EXPECT_EQ(queue.EmptyAndNoMoreAdd(), false);
EXPECT_EQ(queue.Empty(), true);
std::vector<std::thread*> thread_pool;
for (int i = 0; i < kNumOfProducer; ++i) {
thread_pool.push_back(new std::thread(start_add, &queue, i));
}
for (int i = 0; i < kNumOfProducer*kNumOfMessage; ++i) {
Message msg;
EXPECT_EQ(queue.Remove(&msg), REMOVE_SUCCESS);
EXPECT_EQ(string(msg.data, msg.size), string("apple"));
}
for (int i = 0; i < kNumOfProducer; ++i) {
thread_pool[i]->join();
}
EXPECT_EQ(queue.EmptyAndNoMoreAdd(), true);
}
\ No newline at end of file
/*!
* Copyright (c) 2019 by Contributors
* \file socket_communicator_test.cc
* \brief Test SocketCommunicator
*/
#include <gtest/gtest.h>
#include <string.h>
#include <string>
#include <thread>
#include <vector>
#include "../src/graph/network/msg_queue.h"
#include "../src/graph/network/socket_communicator.h"
using std::string;
using dgl::network::SocketSender;
using dgl::network::SocketReceiver;
using dgl::network::Message;
using dgl::network::DefaultMessageDeleter;
const int64_t kQueueSize = 500 * 1024;
#ifndef WIN32
#include <unistd.h>
const int kNumSender = 3;
const int kNumReceiver = 3;
const int kNumMessage = 10;
const char* ip_addr[] = {
"socket://127.0.0.1:50091",
"socket://127.0.0.1:50092",
"socket://127.0.0.1:50093"
};
static void start_client();
static void start_server(int id);
TEST(SocketCommunicatorTest, SendAndRecv) {
// start kNumSender client threads
std::vector<std::thread*> client_thread;
for (int i = 0; i < kNumSender; ++i) {
client_thread.push_back(new std::thread(start_client));
}
// start kNumReceiver server threads
std::vector<std::thread*> server_thread;
for (int i = 0; i < kNumReceiver; ++i) {
server_thread.push_back(new std::thread(start_server, i));
}
for (int i = 0; i < kNumSender; ++i) {
client_thread[i]->join();
}
for (int i = 0; i < kNumReceiver; ++i) {
server_thread[i]->join();
}
}
static void start_client() {
sleep(2);  // wait for the servers to start
SocketSender sender(kQueueSize);
for (int i = 0; i < kNumReceiver; ++i) {
sender.AddReceiver(ip_addr[i], i);
}
sender.Connect();
for (int i = 0; i < kNumMessage; ++i) {
for (int n = 0; n < kNumReceiver; ++n) {
char* str_data = new char[9];
memcpy(str_data, "123456789", 9);
Message msg = {str_data, 9};
msg.deallocator = DefaultMessageDeleter;
EXPECT_EQ(sender.Send(msg, n), ADD_SUCCESS);
}
}
for (int i = 0; i < kNumMessage; ++i) {
for (int n = 0; n < kNumReceiver; ++n) {
char* str_data = new char[9];
memcpy(str_data, "123456789", 9);
Message msg = {str_data, 9};
msg.deallocator = DefaultMessageDeleter;
EXPECT_EQ(sender.Send(msg, n), ADD_SUCCESS);
}
}
sender.Finalize();
}
static void start_server(int id) {
SocketReceiver receiver(kQueueSize);
receiver.Wait(ip_addr[id], kNumSender);
for (int i = 0; i < kNumMessage; ++i) {
for (int n = 0; n < kNumSender; ++n) {
Message msg;
EXPECT_EQ(receiver.RecvFrom(&msg, n), REMOVE_SUCCESS);
EXPECT_EQ(string(msg.data, msg.size), string("123456789"));
msg.deallocator(&msg);
}
}
for (int n = 0; n < kNumSender*kNumMessage; ++n) {
Message msg;
int recv_id;
EXPECT_EQ(receiver.Recv(&msg, &recv_id), REMOVE_SUCCESS);
EXPECT_EQ(string(msg.data, msg.size), string("123456789"));
msg.deallocator(&msg);
}
receiver.Finalize();
}
#else
#include <windows.h>
#include <winsock2.h>
......@@ -37,6 +115,9 @@ void sleep(int seconds) {
Sleep(seconds * 1000);
}
static void start_client();
static bool start_server();
DWORD WINAPI _ClientThreadFunc(LPVOID param) {
start_client();
return 0;
......@@ -70,24 +151,26 @@ TEST(SocketCommunicatorTest, SendAndRecv) {
::WSACleanup();
}
#endif // WIN32
static void start_client() {
sleep(1);
SocketSender sender(kQueueSize);
sender.AddReceiver("socket://127.0.0.1:8001", 0);
sender.Connect();
char* str_data = new char[9];
memcpy(str_data, "123456789", 9);
Message msg = {str_data, 9};
msg.deallocator = DefaultMessageDeleter;
sender.Send(msg, 0);
sender.Finalize();
}
static bool start_server() {
SocketReceiver receiver(kQueueSize);
receiver.Wait("socket://127.0.0.1:8001", 1);
Message msg;
EXPECT_EQ(receiver.RecvFrom(&msg, 0), REMOVE_SUCCESS);
receiver.Finalize();
return string("123456789") == string(serbuff);
return string("123456789") == string(msg.data, msg.size);
}
#endif
/*!
* Copyright (c) 2019 by Contributors
* \file string_test.cc
* \brief Test String Common
*/
#include <gtest/gtest.h>
#include <string>
#include <vector>
#include "../src/graph/network/common.h"
using dgl::network::SplitStringUsing;
using dgl::network::StringPrintf;
using dgl::network::SStringPrintf;
using dgl::network::StringAppendF;
TEST(SplitStringTest, SplitStringUsingCompoundDelim) {
std::string full(" apple \torange ");
std::vector<std::string> subs;
SplitStringUsing(full, " \t", &subs);
EXPECT_EQ(subs.size(), 2);
EXPECT_EQ(subs[0], std::string("apple"));
EXPECT_EQ(subs[1], std::string("orange"));
}
TEST(SplitStringTest, testSplitStringUsingSingleDelim) {
std::string full(" apple orange ");
std::vector<std::string> subs;
SplitStringUsing(full, " ", &subs);
EXPECT_EQ(subs.size(), 2);
EXPECT_EQ(subs[0], std::string("apple"));
EXPECT_EQ(subs[1], std::string("orange"));
}
TEST(SplitStringTest, testSplitingNoDelimString) {
std::string full("apple");
std::vector<std::string> subs;
SplitStringUsing(full, " ", &subs);
EXPECT_EQ(subs.size(), 1);
EXPECT_EQ(subs[0], std::string("apple"));
}
TEST(StringPrintf, normal) {
using std::string;
EXPECT_EQ(StringPrintf("%d", 1), string("1"));
string target;
SStringPrintf(&target, "%d", 1);
EXPECT_EQ(target, string("1"));
StringAppendF(&target, "%d", 2);
EXPECT_EQ(target, string("12"));
}
\ No newline at end of file