Unverified Commit 5cf48fc6 authored by Jingcheng Yu's avatar Jingcheng Yu Committed by GitHub
Browse files

[Feature] Implement one thread multiple socket (#3200)


Co-authored-by: default avatarJingchengYu94 <jingchengyu94@gmail.com>
parent 179d6aab
...@@ -34,6 +34,7 @@ dgl_option(BUILD_CPP_TEST "Build cpp unittest executables" OFF) ...@@ -34,6 +34,7 @@ dgl_option(BUILD_CPP_TEST "Build cpp unittest executables" OFF)
dgl_option(LIBCXX_ENABLE_PARALLEL_ALGORITHMS "Enable the parallel algorithms library. This requires the PSTL to be available." OFF) dgl_option(LIBCXX_ENABLE_PARALLEL_ALGORITHMS "Enable the parallel algorithms library. This requires the PSTL to be available." OFF)
dgl_option(USE_S3 "Build with S3 support" OFF) dgl_option(USE_S3 "Build with S3 support" OFF)
dgl_option(USE_HDFS "Build with HDFS support" OFF) # Set env HADOOP_HDFS_HOME if needed dgl_option(USE_HDFS "Build with HDFS support" OFF) # Set env HADOOP_HDFS_HOME if needed
dgl_option(USE_EPOLL "Build with epoll for socket communicator" OFF)
# Set debug compile option for gdb, only happens when -DCMAKE_BUILD_TYPE=DEBUG # Set debug compile option for gdb, only happens when -DCMAKE_BUILD_TYPE=DEBUG
if (NOT MSVC) if (NOT MSVC)
...@@ -120,6 +121,14 @@ if(USE_AVX) ...@@ -120,6 +121,14 @@ if(USE_AVX)
endif(USE_LIBXSMM) endif(USE_LIBXSMM)
endif(USE_AVX) endif(USE_AVX)
if (USE_EPOLL)
check_include_file("sys/epoll.h" USE_EPOLL)
if (USE_EPOLL)
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -DUSE_EPOLL")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DUSE_EPOLL")
endif()
endif ()
# Build with fp16 to support mixed precision training. # Build with fp16 to support mixed precision training.
if(USE_FP16) if(USE_FP16)
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -DUSE_FP16") set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -DUSE_FP16")
......
"""RPC components. They are typically functions or utilities used by both """RPC components. They are typically functions or utilities used by both
server and clients.""" server and clients."""
import os
import abc import abc
import pickle import pickle
import random import random
...@@ -111,7 +112,8 @@ def create_sender(max_queue_size, net_type): ...@@ -111,7 +112,8 @@ def create_sender(max_queue_size, net_type):
net_type : str net_type : str
Networking type. Current options are: 'socket'. Networking type. Current options are: 'socket'.
""" """
_CAPI_DGLRPCCreateSender(int(max_queue_size), net_type) max_thread_count = int(os.getenv('DGL_SOCKET_MAX_THREAD_COUNT', '0'))
_CAPI_DGLRPCCreateSender(int(max_queue_size), net_type, max_thread_count)
def create_receiver(max_queue_size, net_type): def create_receiver(max_queue_size, net_type):
"""Create rpc receiver of this process. """Create rpc receiver of this process.
...@@ -123,7 +125,8 @@ def create_receiver(max_queue_size, net_type): ...@@ -123,7 +125,8 @@ def create_receiver(max_queue_size, net_type):
net_type : str net_type : str
Networking type. Current options are: 'socket'. Networking type. Current options are: 'socket'.
""" """
_CAPI_DGLRPCCreateReceiver(int(max_queue_size), net_type) max_thread_count = int(os.getenv('DGL_SOCKET_MAX_THREAD_COUNT', '0'))
_CAPI_DGLRPCCreateReceiver(int(max_queue_size), net_type, max_thread_count)
def finalize_sender(): def finalize_sender():
"""Finalize rpc sender of this process. """Finalize rpc sender of this process.
......
...@@ -206,7 +206,7 @@ DGL_REGISTER_GLOBAL("network._CAPI_DGLSenderCreate") ...@@ -206,7 +206,7 @@ DGL_REGISTER_GLOBAL("network._CAPI_DGLSenderCreate")
int64_t msg_queue_size = args[1]; int64_t msg_queue_size = args[1];
network::Sender* sender = nullptr; network::Sender* sender = nullptr;
if (type == "socket") { if (type == "socket") {
sender = new network::SocketSender(msg_queue_size); sender = new network::SocketSender(msg_queue_size, 0);
} else { } else {
LOG(FATAL) << "Unknown communicator type: " << type; LOG(FATAL) << "Unknown communicator type: " << type;
} }
...@@ -220,7 +220,7 @@ DGL_REGISTER_GLOBAL("network._CAPI_DGLReceiverCreate") ...@@ -220,7 +220,7 @@ DGL_REGISTER_GLOBAL("network._CAPI_DGLReceiverCreate")
int64_t msg_queue_size = args[1]; int64_t msg_queue_size = args[1];
network::Receiver* receiver = nullptr; network::Receiver* receiver = nullptr;
if (type == "socket") { if (type == "socket") {
receiver = new network::SocketReceiver(msg_queue_size); receiver = new network::SocketReceiver(msg_queue_size, 0);
} else { } else {
LOG(FATAL) << "Unknown communicator type: " << type; LOG(FATAL) << "Unknown communicator type: " << type;
} }
......
...@@ -28,11 +28,14 @@ class Sender { ...@@ -28,11 +28,14 @@ class Sender {
/*! /*!
* \brief Sender constructor * \brief Sender constructor
* \param queue_size size (bytes) of message queue. * \param queue_size size (bytes) of message queue.
* \param max_thread_count size of thread pool. 0 for no limit
* Note that, the queue_size parameter is optional. * Note that, the queue_size parameter is optional.
*/ */
explicit Sender(int64_t queue_size = 0) { explicit Sender(int64_t queue_size = 0, int max_thread_count = 0) {
CHECK_GE(queue_size, 0); CHECK_GE(queue_size, 0);
CHECK_GE(max_thread_count, 0);
queue_size_ = queue_size; queue_size_ = queue_size;
max_thread_count_ = max_thread_count;
} }
virtual ~Sender() {} virtual ~Sender() {}
...@@ -86,6 +89,10 @@ class Sender { ...@@ -86,6 +89,10 @@ class Sender {
* \brief Size of message queue * \brief Size of message queue
*/ */
int64_t queue_size_; int64_t queue_size_;
/*!
* \brief Size of thread pool. 0 for no limit
*/
int max_thread_count_;
}; };
/*! /*!
...@@ -101,13 +108,16 @@ class Receiver { ...@@ -101,13 +108,16 @@ class Receiver {
/*! /*!
* \brief Receiver constructor * \brief Receiver constructor
* \param queue_size size of message queue. * \param queue_size size of message queue.
* \param max_thread_count size of thread pool. 0 for no limit
* Note that, the queue_size parameter is optional. * Note that, the queue_size parameter is optional.
*/ */
explicit Receiver(int64_t queue_size = 0) { explicit Receiver(int64_t queue_size = 0, int max_thread_count = 0) {
if (queue_size < 0) { if (queue_size < 0) {
LOG(FATAL) << "queue_size cannot be a negative number."; LOG(FATAL) << "queue_size cannot be a negative number.";
} }
CHECK_GE(max_thread_count, 0);
queue_size_ = queue_size; queue_size_ = queue_size;
max_thread_count_ = max_thread_count;
} }
virtual ~Receiver() {} virtual ~Receiver() {}
...@@ -165,6 +175,10 @@ class Receiver { ...@@ -165,6 +175,10 @@ class Receiver {
* \brief Size of message queue * \brief Size of message queue
*/ */
int64_t queue_size_; int64_t queue_size_;
/*!
* \brief Size of thread pool. 0 for no limit
*/
int max_thread_count_;
}; };
} // namespace network } // namespace network
......
...@@ -72,6 +72,7 @@ STATUS MessageQueue::Remove(Message* msg, bool is_blocking) { ...@@ -72,6 +72,7 @@ STATUS MessageQueue::Remove(Message* msg, bool is_blocking) {
queue_.pop(); queue_.pop();
msg->data = old_msg.data; msg->data = old_msg.data;
msg->size = old_msg.size; msg->size = old_msg.size;
msg->receiver_id = old_msg.receiver_id;
msg->deallocator = old_msg.deallocator; msg->deallocator = old_msg.deallocator;
free_size_ += old_msg.size; free_size_ += old_msg.size;
cond_not_full_.notify_one(); cond_not_full_.notify_one();
......
...@@ -56,6 +56,10 @@ struct Message { ...@@ -56,6 +56,10 @@ struct Message {
* \brief message size in bytes * \brief message size in bytes
*/ */
int64_t size; int64_t size;
/*!
* \brief message receiver id
*/
int receiver_id = -1;
/*! /*!
* \brief user-defined deallocator, which can be nullptr * \brief user-defined deallocator, which can be nullptr
*/ */
......
...@@ -12,6 +12,7 @@ ...@@ -12,6 +12,7 @@
#include "socket_communicator.h" #include "socket_communicator.h"
#include "../../c_api_common.h" #include "../../c_api_common.h"
#include "socket_pool.h"
#ifdef _WIN32 #ifdef _WIN32
#include <windows.h> #include <windows.h>
...@@ -51,15 +52,20 @@ void SocketSender::AddReceiver(const char* addr, int recv_id) { ...@@ -51,15 +52,20 @@ void SocketSender::AddReceiver(const char* addr, int recv_id) {
address.ip = ip_and_port[0]; address.ip = ip_and_port[0];
address.port = std::stoi(ip_and_port[1]); address.port = std::stoi(ip_and_port[1]);
receiver_addrs_[recv_id] = address; receiver_addrs_[recv_id] = address;
msg_queue_[recv_id] = std::make_shared<MessageQueue>(queue_size_);
} }
bool SocketSender::Connect() { bool SocketSender::Connect() {
// Create N sockets for Receiver // Create N sockets for Receiver
int receiver_count = static_cast<int>(receiver_addrs_.size());
if (max_thread_count_ == 0 || max_thread_count_ > receiver_count) {
max_thread_count_ = receiver_count;
}
sockets_.resize(max_thread_count_);
for (const auto& r : receiver_addrs_) { for (const auto& r : receiver_addrs_) {
int ID = r.first; int receiver_id = r.first;
sockets_[ID] = std::make_shared<TCPSocket>(); int thread_id = receiver_id % max_thread_count_;
TCPSocket* client_socket = sockets_[ID].get(); sockets_[thread_id][receiver_id] = std::make_shared<TCPSocket>();
TCPSocket* client_socket = sockets_[thread_id][receiver_id].get();
bool bo = false; bool bo = false;
int try_count = 0; int try_count = 0;
const char* ip = r.second.ip.c_str(); const char* ip = r.second.ip.c_str();
...@@ -83,12 +89,17 @@ bool SocketSender::Connect() { ...@@ -83,12 +89,17 @@ bool SocketSender::Connect() {
if (bo == false) { if (bo == false) {
return bo; return bo;
} }
}
for (int thread_id = 0; thread_id < max_thread_count_; ++thread_id) {
msg_queue_.push_back(std::make_shared<MessageQueue>(queue_size_));
// Create a new thread for this socket connection // Create a new thread for this socket connection
threads_[ID] = std::make_shared<std::thread>( threads_.push_back(std::make_shared<std::thread>(
SendLoop, SendLoop,
client_socket, sockets_[thread_id],
msg_queue_[ID].get()); msg_queue_[thread_id]));
} }
return true; return true;
} }
...@@ -96,53 +107,48 @@ STATUS SocketSender::Send(Message msg, int recv_id) { ...@@ -96,53 +107,48 @@ STATUS SocketSender::Send(Message msg, int recv_id) {
CHECK_NOTNULL(msg.data); CHECK_NOTNULL(msg.data);
CHECK_GT(msg.size, 0); CHECK_GT(msg.size, 0);
CHECK_GE(recv_id, 0); CHECK_GE(recv_id, 0);
msg.receiver_id = recv_id;
// Add data message to message queue // Add data message to message queue
STATUS code = msg_queue_[recv_id]->Add(msg); STATUS code = msg_queue_[recv_id % max_thread_count_]->Add(msg);
return code; return code;
} }
void SocketSender::Finalize() { void SocketSender::Finalize() {
// Send a signal to tell the msg_queue to finish its job // Send a signal to tell the msg_queue to finish its job
for (auto& mq : msg_queue_) { for (int i = 0; i < max_thread_count_; ++i) {
// wait until queue is empty // wait until queue is empty
while (mq.second->Empty() == false) { auto& mq = msg_queue_[i];
while (mq->Empty() == false) {
#ifdef _WIN32 #ifdef _WIN32
// just loop // just loop
#else // !_WIN32 #else // !_WIN32
usleep(1000); usleep(1000);
#endif // _WIN32 #endif // _WIN32
} }
int ID = mq.first; // All queues have only one producer, which is main thread, so
mq.second->SignalFinished(ID); // the producerID argument here should be zero.
mq->SignalFinished(0);
} }
// Block main thread until all socket-threads finish their jobs // Block main thread until all socket-threads finish their jobs
for (auto& thread : threads_) { for (auto& thread : threads_) {
thread.second->join(); thread->join();
} }
// Clear all sockets // Clear all sockets
for (auto& socket : sockets_) { for (auto& group_sockets_ : sockets_) {
for (auto &socket : group_sockets_) {
socket.second->Close(); socket.second->Close();
} }
}
} }
void SocketSender::SendLoop(TCPSocket* socket, MessageQueue* queue) { void SendCore(Message msg, TCPSocket* socket) {
CHECK_NOTNULL(socket);
CHECK_NOTNULL(queue);
bool exit = false;
while (!exit) {
Message msg;
STATUS code = queue->Remove(&msg);
if (code == QUEUE_CLOSE) {
msg.size = 0; // send an end-signal to receiver
exit = true;
}
// First send the size // First send the size
// If exit == true, we will send zero size to reciever // If exit == true, we will send zero size to reciever
int64_t sent_bytes = 0; int64_t sent_bytes = 0;
while (static_cast<size_t>(sent_bytes) < sizeof(int64_t)) { while (static_cast<size_t>(sent_bytes) < sizeof(int64_t)) {
int64_t max_len = sizeof(int64_t) - sent_bytes; int64_t max_len = sizeof(int64_t) - sent_bytes;
int64_t tmp = socket->Send( int64_t tmp = socket->Send(
reinterpret_cast<char*>(&msg.size)+sent_bytes, reinterpret_cast<char*>(&msg.size) + sent_bytes,
max_len); max_len);
CHECK_NE(tmp, -1); CHECK_NE(tmp, -1);
sent_bytes += tmp; sent_bytes += tmp;
...@@ -159,6 +165,22 @@ void SocketSender::SendLoop(TCPSocket* socket, MessageQueue* queue) { ...@@ -159,6 +165,22 @@ void SocketSender::SendLoop(TCPSocket* socket, MessageQueue* queue) {
if (msg.deallocator != nullptr) { if (msg.deallocator != nullptr) {
msg.deallocator(&msg); msg.deallocator(&msg);
} }
}
void SocketSender::SendLoop(
std::unordered_map<int, std::shared_ptr<TCPSocket>> sockets,
std::shared_ptr<MessageQueue> queue) {
for (;;) {
Message msg;
STATUS code = queue->Remove(&msg);
if (code == QUEUE_CLOSE) {
msg.size = 0; // send an end-signal to receiver
for (auto& socket : sockets) {
SendCore(msg, socket.second.get());
}
break;
}
SendCore(msg, sockets[msg.receiver_id].get());
} }
} }
...@@ -187,16 +209,20 @@ bool SocketReceiver::Wait(const char* addr, int num_sender) { ...@@ -187,16 +209,20 @@ bool SocketReceiver::Wait(const char* addr, int num_sender) {
int port = stoi(ip_and_port[1]); int port = stoi(ip_and_port[1]);
// Initialize message queue for each connection // Initialize message queue for each connection
num_sender_ = num_sender; num_sender_ = num_sender;
for (int i = 0; i < num_sender_; ++i) { #ifdef USE_EPOLL
msg_queue_[i] = std::make_shared<MessageQueue>(queue_size_); if (max_thread_count_ == 0 || max_thread_count_ > num_sender_) {
max_thread_count_ = num_sender_;
} }
mq_iter_ = msg_queue_.begin(); #else
max_thread_count_ = num_sender_;
#endif
// Initialize socket and socket-thread // Initialize socket and socket-thread
server_socket_ = new TCPSocket(); server_socket_ = new TCPSocket();
// Bind socket // Bind socket
if (server_socket_->Bind(ip.c_str(), port) == false) { if (server_socket_->Bind(ip.c_str(), port) == false) {
LOG(FATAL) << "Cannot bind to " << ip << ":" << port; LOG(FATAL) << "Cannot bind to " << ip << ":" << port;
} }
// Listen // Listen
if (server_socket_->Listen(kMaxConnection) == false) { if (server_socket_->Listen(kMaxConnection) == false) {
LOG(FATAL) << "Cannot listen on " << ip << ":" << port; LOG(FATAL) << "Cannot listen on " << ip << ":" << port;
...@@ -204,27 +230,39 @@ bool SocketReceiver::Wait(const char* addr, int num_sender) { ...@@ -204,27 +230,39 @@ bool SocketReceiver::Wait(const char* addr, int num_sender) {
// Accept all sender sockets // Accept all sender sockets
std::string accept_ip; std::string accept_ip;
int accept_port; int accept_port;
sockets_.resize(max_thread_count_);
for (int i = 0; i < num_sender_; ++i) { for (int i = 0; i < num_sender_; ++i) {
sockets_[i] = std::make_shared<TCPSocket>(); int thread_id = i % max_thread_count_;
if (server_socket_->Accept(sockets_[i].get(), &accept_ip, &accept_port) == false) { auto socket = std::make_shared<TCPSocket>();
sockets_[thread_id][i] = socket;
msg_queue_[i] = std::make_shared<MessageQueue>(queue_size_);
if (server_socket_->Accept(socket.get(), &accept_ip, &accept_port) == false) {
LOG(WARNING) << "Error on accept socket."; LOG(WARNING) << "Error on accept socket.";
return false; return false;
} }
}
mq_iter_ = msg_queue_.begin();
for (int thread_id = 0; thread_id < max_thread_count_; ++thread_id) {
// create new thread for each socket // create new thread for each socket
threads_[i] = std::make_shared<std::thread>( threads_.push_back(std::make_shared<std::thread>(
RecvLoop, RecvLoop,
sockets_[i].get(), sockets_[thread_id],
msg_queue_[i].get()); msg_queue_,
&queue_sem_));
} }
return true; return true;
} }
STATUS SocketReceiver::Recv(Message* msg, int* send_id) { STATUS SocketReceiver::Recv(Message* msg, int* send_id) {
// loop until get a message // queue_sem_ is a semaphore indicating how many elements in multiple
// message queues.
// When calling queue_sem_.Wait(), this Recv will be suspended until
// queue_sem_ > 0, decrease queue_sem_ by 1, then start to fetch a message.
queue_sem_.Wait();
for (;;) { for (;;) {
for (; mq_iter_ != msg_queue_.end(); ++mq_iter_) { for (; mq_iter_ != msg_queue_.end(); ++mq_iter_) {
// We use non-block remove here
STATUS code = mq_iter_->second->Remove(msg, false); STATUS code = mq_iter_->second->Remove(msg, false);
if (code == QUEUE_EMPTY) { if (code == QUEUE_EMPTY) {
continue; // jump to the next queue continue; // jump to the next queue
...@@ -240,6 +278,7 @@ STATUS SocketReceiver::Recv(Message* msg, int* send_id) { ...@@ -240,6 +278,7 @@ STATUS SocketReceiver::Recv(Message* msg, int* send_id) {
STATUS SocketReceiver::RecvFrom(Message* msg, int send_id) { STATUS SocketReceiver::RecvFrom(Message* msg, int send_id) {
// Get message from specified message queue // Get message from specified message queue
queue_sem_.Wait();
STATUS code = msg_queue_[send_id]->Remove(msg); STATUS code = msg_queue_[send_id]->Remove(msg);
return code; return code;
} }
...@@ -255,47 +294,93 @@ void SocketReceiver::Finalize() { ...@@ -255,47 +294,93 @@ void SocketReceiver::Finalize() {
usleep(1000); usleep(1000);
#endif // _WIN32 #endif // _WIN32
} }
int ID = mq.first; mq.second->SignalFinished(mq.first);
mq.second->SignalFinished(ID);
} }
// Block main thread until all socket-threads finish their jobs // Block main thread until all socket-threads finish their jobs
for (auto& thread : threads_) { for (auto& thread : threads_) {
thread.second->join(); thread->join();
} }
// Clear all sockets // Clear all sockets
for (auto& socket : sockets_) { for (auto& group_sockets : sockets_) {
for (auto& socket : group_sockets) {
socket.second->Close(); socket.second->Close();
} }
}
server_socket_->Close(); server_socket_->Close();
delete server_socket_; delete server_socket_;
} }
void SocketReceiver::RecvLoop(TCPSocket* socket, MessageQueue* queue) { int64_t RecvDataSize(TCPSocket* socket) {
CHECK_NOTNULL(socket);
CHECK_NOTNULL(queue);
for (;;) {
// If main thread had finished its job
if (queue->EmptyAndNoMoreAdd()) {
return; // exit loop thread
}
// First recv the size
int64_t received_bytes = 0; int64_t received_bytes = 0;
int64_t data_size = 0; int64_t data_size = 0;
while (static_cast<size_t>(received_bytes) < sizeof(int64_t)) { while (static_cast<size_t>(received_bytes) < sizeof(int64_t)) {
int64_t max_len = sizeof(int64_t) - received_bytes; int64_t max_len = sizeof(int64_t) - received_bytes;
int64_t tmp = socket->Receive( int64_t tmp = socket->Receive(
reinterpret_cast<char*>(&data_size)+received_bytes, reinterpret_cast<char*>(&data_size) + received_bytes,
max_len); max_len);
CHECK_NE(tmp, -1); if (tmp == -1) {
if (received_bytes > 0) {
// We want to finish reading full data_size
continue;
}
return -1;
}
received_bytes += tmp; received_bytes += tmp;
} }
if (data_size < 0) { return data_size;
LOG(FATAL) << "Recv data error (data_size: " << data_size << ")"; }
} else if (data_size == 0) {
// This is an end-signal sent by client void RecvData(TCPSocket* socket, char* buffer, const int64_t &data_size,
int64_t *received_bytes) {
while (*received_bytes < data_size) {
int64_t max_len = data_size - *received_bytes;
int64_t tmp = socket->Receive(buffer + *received_bytes, max_len);
if (tmp == -1) {
// Socket not ready, no more data to read
return; return;
} else { }
char* buffer = nullptr; *received_bytes += tmp;
}
}
void SocketReceiver::RecvLoop(
std::unordered_map<int /* Sender (virtual) ID */,
std::shared_ptr<TCPSocket>> sockets,
std::unordered_map<int /* Sender (virtual) ID */,
std::shared_ptr<MessageQueue>> queues,
runtime::Semaphore *queue_sem) {
std::unordered_map<int, std::unique_ptr<RecvContext>> recv_contexts;
SocketPool socket_pool;
for (auto& socket : sockets) {
auto &sender_id = socket.first;
socket_pool.AddSocket(socket.second, sender_id);
recv_contexts[sender_id] = std::unique_ptr<RecvContext>(new RecvContext());
}
// Main loop to receive messages
for (;;) {
int sender_id;
// Get active socket using epoll
std::shared_ptr<TCPSocket> socket = socket_pool.GetActiveSocket(&sender_id);
if (queues[sender_id]->EmptyAndNoMoreAdd()) {
// This sender has already stopped
if (socket_pool.RemoveSocket(socket) == 0) {
return;
}
continue;
}
// Nonblocking socket might be interrupted at any point. So we need to
// store the partially received data
std::unique_ptr<RecvContext> &ctx = recv_contexts[sender_id];
int64_t &data_size = ctx->data_size;
int64_t &received_bytes = ctx->received_bytes;
char*& buffer = ctx->buffer;
if (data_size == -1) {
// This is a new message, so receive the data size first
data_size = RecvDataSize(socket.get());
if (data_size > 0) {
try { try {
buffer = new char[data_size]; buffer = new char[data_size];
} catch(const std::bad_alloc&) { } catch(const std::bad_alloc&) {
...@@ -303,17 +388,28 @@ void SocketReceiver::RecvLoop(TCPSocket* socket, MessageQueue* queue) { ...@@ -303,17 +388,28 @@ void SocketReceiver::RecvLoop(TCPSocket* socket, MessageQueue* queue) {
<< "(message size: " << data_size << ")"; << "(message size: " << data_size << ")";
} }
received_bytes = 0; received_bytes = 0;
while (received_bytes < data_size) { } else if (data_size == 0) {
int64_t max_len = data_size - received_bytes; // Received stop signal
int64_t tmp = socket->Receive(buffer+received_bytes, max_len); if (socket_pool.RemoveSocket(socket) == 0) {
CHECK_NE(tmp, -1); return;
received_bytes += tmp; }
}
} }
RecvData(socket.get(), buffer, data_size, &received_bytes);
if (received_bytes >= data_size) {
// Full data received, create Message and push to queue
Message msg; Message msg;
msg.data = buffer; msg.data = buffer;
msg.size = data_size; msg.size = data_size;
msg.deallocator = DefaultMessageDeleter; msg.deallocator = DefaultMessageDeleter;
queue->Add(msg); queues[sender_id]->Add(msg);
// Reset recv context
data_size = -1;
// Signal queue semaphore
queue_sem->Post();
} }
} }
} }
......
...@@ -12,6 +12,7 @@ ...@@ -12,6 +12,7 @@
#include <unordered_map> #include <unordered_map>
#include <memory> #include <memory>
#include "../../runtime/semaphore_wrapper.h"
#include "communicator.h" #include "communicator.h"
#include "msg_queue.h" #include "msg_queue.h"
#include "tcp_socket.h" #include "tcp_socket.h"
...@@ -42,8 +43,10 @@ class SocketSender : public Sender { ...@@ -42,8 +43,10 @@ class SocketSender : public Sender {
/*! /*!
* \brief Sender constructor * \brief Sender constructor
* \param queue_size size of message queue * \param queue_size size of message queue
* \param max_thread_count size of thread pool. 0 for no limit
*/ */
explicit SocketSender(int64_t queue_size) : Sender(queue_size) {} SocketSender(int64_t queue_size, int max_thread_count)
: Sender(queue_size, max_thread_count) {}
/*! /*!
* \brief Add receiver's address and ID to the sender's namebook * \brief Add receiver's address and ID to the sender's namebook
...@@ -93,7 +96,8 @@ class SocketSender : public Sender { ...@@ -93,7 +96,8 @@ class SocketSender : public Sender {
/*! /*!
* \brief socket for each connection of receiver * \brief socket for each connection of receiver
*/ */
std::unordered_map<int /* receiver ID */, std::shared_ptr<TCPSocket>> sockets_; std::vector<std::unordered_map<int /* receiver ID */,
std::shared_ptr<TCPSocket>>> sockets_;
/*! /*!
* \brief receivers' address * \brief receivers' address
...@@ -101,24 +105,27 @@ class SocketSender : public Sender { ...@@ -101,24 +105,27 @@ class SocketSender : public Sender {
std::unordered_map<int /* receiver ID */, IPAddr> receiver_addrs_; std::unordered_map<int /* receiver ID */, IPAddr> receiver_addrs_;
/*! /*!
* \brief message queue for each socket connection * \brief message queue for each thread
*/ */
std::unordered_map<int /* receiver ID */, std::shared_ptr<MessageQueue>> msg_queue_; std::vector<std::shared_ptr<MessageQueue>> msg_queue_;
/*! /*!
* \brief Independent thread for each socket connection * \brief Independent thread
*/ */
std::unordered_map<int /* receiver ID */, std::shared_ptr<std::thread>> threads_; std::vector<std::shared_ptr<std::thread>> threads_;
/*! /*!
* \brief Send-loop for each socket in per-thread * \brief Send-loop for each thread
* \param socket TCPSocket for current connection * \param sockets TCPSockets for current thread
* \param queue message_queue for current connection * \param queue message_queue for current thread
* *
* Note that, the SendLoop will finish its loop-job and exit thread * Note that, the SendLoop will finish its loop-job and exit thread
* when the main thread invokes Signal() API on the message queue. * when the main thread invokes Signal() API on the message queue.
*/ */
static void SendLoop(TCPSocket* socket, MessageQueue* queue); static void SendLoop(
std::unordered_map<int /* Receiver (virtual) ID */,
std::shared_ptr<TCPSocket>> sockets,
std::shared_ptr<MessageQueue> queue);
}; };
/*! /*!
...@@ -131,8 +138,10 @@ class SocketReceiver : public Receiver { ...@@ -131,8 +138,10 @@ class SocketReceiver : public Receiver {
/*! /*!
* \brief Receiver constructor * \brief Receiver constructor
* \param queue_size size of message queue. * \param queue_size size of message queue.
* \param max_thread_count size of thread pool. 0 for no limit
*/ */
explicit SocketReceiver(int64_t queue_size) : Receiver(queue_size) {} SocketReceiver(int64_t queue_size, int max_thread_count)
: Receiver(queue_size, max_thread_count) {}
/*! /*!
* \brief Wait for all the Senders to connect * \brief Wait for all the Senders to connect
...@@ -183,6 +192,11 @@ class SocketReceiver : public Receiver { ...@@ -183,6 +192,11 @@ class SocketReceiver : public Receiver {
inline std::string Type() const { return std::string("socket"); } inline std::string Type() const { return std::string("socket"); }
private: private:
struct RecvContext {
int64_t data_size = -1;
int64_t received_bytes = 0;
char *buffer = nullptr;
};
/*! /*!
* \brief number of sender * \brief number of sender
*/ */
...@@ -196,28 +210,41 @@ class SocketReceiver : public Receiver { ...@@ -196,28 +210,41 @@ class SocketReceiver : public Receiver {
/*! /*!
* \brief socket for each client connections * \brief socket for each client connections
*/ */
std::unordered_map<int /* Sender (virutal) ID */, std::shared_ptr<TCPSocket>> sockets_; std::vector<std::unordered_map<int /* Sender (virutal) ID */,
std::shared_ptr<TCPSocket>>> sockets_;
/*! /*!
* \brief Message queue for each socket connection * \brief Message queue for each socket connection
*/ */
std::unordered_map<int /* Sender (virtual) ID */, std::shared_ptr<MessageQueue>> msg_queue_; std::unordered_map<int /* Sender (virtual) ID */,
std::shared_ptr<MessageQueue>> msg_queue_;
std::unordered_map<int, std::shared_ptr<MessageQueue>>::iterator mq_iter_; std::unordered_map<int, std::shared_ptr<MessageQueue>>::iterator mq_iter_;
/*! /*!
* \brief Independent thead for each socket connection * \brief Independent thead
*/ */
std::unordered_map<int /* Sender (virtual) ID */, std::shared_ptr<std::thread>> threads_; std::vector<std::shared_ptr<std::thread>> threads_;
/*! /*!
* \brief Recv-loop for each socket in per-thread * \brief queue_sem_ semphore to indicate number of messages in multiple
* \param socket client socket * message queues to prevent busy wait of Recv
* \param queue message queue */
runtime::Semaphore queue_sem_;
/*!
* \brief Recv-loop for each thread
* \param sockets client sockets of current thread
* \param queue message queues of current thread
* *
* Note that, the RecvLoop will finish its loop-job and exit thread * Note that, the RecvLoop will finish its loop-job and exit thread
* when the main thread invokes Signal() API on the message queue. * when the main thread invokes Signal() API on the message queue.
*/ */
static void RecvLoop(TCPSocket* socket, MessageQueue* queue); static void RecvLoop(
std::unordered_map<int /* Sender (virtual) ID */,
std::shared_ptr<TCPSocket>> sockets,
std::unordered_map<int /* Sender (virtual) ID */,
std::shared_ptr<MessageQueue>> queues,
runtime::Semaphore *queue_sem);
}; };
} // namespace network } // namespace network
......
/*!
* Copyright (c) 2021 by Contributors
* \file socket_pool.cc
* \brief Socket pool of nonblocking sockets for DGL distributed training.
*/
#include "socket_pool.h"
#include <dmlc/logging.h>
#include "tcp_socket.h"
#ifdef USE_EPOLL
#include <sys/epoll.h>
#endif
namespace dgl {
namespace network {
SocketPool::SocketPool() {
#ifdef USE_EPOLL
epfd_ = epoll_create1(0);
if (epfd_ < 0) {
LOG(FATAL) << "SocketPool cannot create epfd";
}
#endif
}
void SocketPool::AddSocket(std::shared_ptr<TCPSocket> socket, int socket_id,
int events) {
int fd = socket->Socket();
tcp_sockets_[fd] = socket;
socket_ids_[fd] = socket_id;
#ifdef USE_EPOLL
epoll_event e;
e.data.fd = fd;
if (events == READ) {
e.events = EPOLLIN;
} else if (events == WRITE) {
e.events = EPOLLOUT;
} else if (events == READ + WRITE) {
e.events = EPOLLIN | EPOLLOUT;
}
if (epoll_ctl(epfd_, EPOLL_CTL_ADD, fd, &e) < 0) {
LOG(FATAL) << "SocketPool cannot add socket";
}
socket->SetNonBlocking(true);
#else
if (tcp_sockets_.size() > 1) {
LOG(FATAL) << "SocketPool supports only one socket if not use epoll."
"Please turn on USE_EPOLL on building";
}
#endif
}
size_t SocketPool::RemoveSocket(std::shared_ptr<TCPSocket> socket) {
int fd = socket->Socket();
socket_ids_.erase(fd);
tcp_sockets_.erase(fd);
#ifdef USE_EPOLL
epoll_ctl(epfd_, EPOLL_CTL_DEL, fd, NULL);
#endif
return socket_ids_.size();
}
SocketPool::~SocketPool() {
#ifdef USE_EPOLL
for (auto& id : socket_ids_) {
int fd = id.first;
epoll_ctl(epfd_, EPOLL_CTL_DEL, fd, NULL);
}
#endif
}
std::shared_ptr<TCPSocket> SocketPool::GetActiveSocket(int* socket_id) {
if (socket_ids_.empty()) {
return nullptr;
}
for (;;) {
while (pending_fds_.empty()) {
Wait();
}
int fd = pending_fds_.front();
pending_fds_.pop();
// Check if this socket is not removed
if (socket_ids_.find(fd) != socket_ids_.end()) {
*socket_id = socket_ids_[fd];
return tcp_sockets_[fd];
}
}
return nullptr;
}
void SocketPool::Wait() {
#ifdef USE_EPOLL
static const int MAX_EVENTS = 10;
epoll_event events[MAX_EVENTS];
int nfd = epoll_wait(epfd_, events, MAX_EVENTS, -1 /*Timeout*/);
for (int i = 0; i < nfd; ++i) {
pending_fds_.push(events[i].data.fd);
}
#else
pending_fds_.push(tcp_sockets_.begin()->second->Socket());
#endif
}
} // namespace network
} // namespace dgl
/*!
* Copyright (c) 2021 by Contributors
* \file socket_pool.h
* \brief Socket pool of nonblocking sockets for DGL distributed training.
*/
#ifndef DGL_RPC_NETWORK_SOCKET_POOL_H_
#define DGL_RPC_NETWORK_SOCKET_POOL_H_
#include <unordered_map>
#include <queue>
#include <memory>
namespace dgl {
namespace network {
class TCPSocket;
/*!
* \brief SocketPool maintains a group of nonblocking sockets, and can provide
* active sockets.
* Currently SocketPool is based on epoll, a scalable I/O event notification
* mechanism in Linux operating system.
*/
class SocketPool {
public:
/*!
* \brief socket mode read/receive
*/
static const int READ = 1;
/*!
* \brief socket mode write/send
*/
static const int WRITE = 2;
/*!
* \brief SocketPool constructor
*/
SocketPool();
/*!
* \brief Add a socket to SocketPool
* \param socket tcp socket to add
* \param socket_id receiver/sender id of the socket
* \param events READ, WRITE or READ + WRITE
*/
void AddSocket(std::shared_ptr<TCPSocket> socket, int socket_id,
int events = READ);
/*!
* \brief Remove socket from SocketPool
* \param socket tcp socket to remove
* \return number of remaing sockets in the pool
*/
size_t RemoveSocket(std::shared_ptr<TCPSocket> socket);
/*!
* \brief SocketPool destructor
*/
~SocketPool();
/*!
* \brief Get current active socket. This is a blocking method
* \param socket_id output parameter of the socket_id of active socket
* \return active TCPSocket
*/
std::shared_ptr<TCPSocket> GetActiveSocket(int* socket_id);
private:
/*!
* \brief Wait for event notification
*/
void Wait();
/*!
* \brief map from fd to TCPSocket
*/
std::unordered_map<int, std::shared_ptr<TCPSocket>> tcp_sockets_;
/*!
* \brief map from fd to socket_id
*/
std::unordered_map<int, int> socket_ids_;
/*!
* \brief fd for epoll base
*/
int epfd_;
/*!
* \brief queue for current active fds
*/
std::queue<int> pending_fds_;
};
} // namespace network
} // namespace dgl
#endif // DGL_RPC_NETWORK_SOCKET_POOL_H_
...@@ -119,7 +119,7 @@ bool TCPSocket::Accept(TCPSocket * socket, std::string * ip, int * port) { ...@@ -119,7 +119,7 @@ bool TCPSocket::Accept(TCPSocket * socket, std::string * ip, int * port) {
} }
#ifdef _WIN32 #ifdef _WIN32
bool TCPSocket::SetBlocking(bool flag) { bool TCPSocket::SetNonBlocking(bool flag) {
int result; int result;
u_long argp = flag ? 1 : 0; u_long argp = flag ? 1 : 0;
...@@ -134,7 +134,7 @@ bool TCPSocket::SetBlocking(bool flag) { ...@@ -134,7 +134,7 @@ bool TCPSocket::SetBlocking(bool flag) {
return true; return true;
} }
#else // !_WIN32 #else // !_WIN32
bool TCPSocket::SetBlocking(bool flag) { bool TCPSocket::SetNonBlocking(bool flag) {
int opts; int opts;
if ((opts = fcntl(socket_, F_GETFL)) < 0) { if ((opts = fcntl(socket_, F_GETFL)) < 0) {
...@@ -205,7 +205,7 @@ int64_t TCPSocket::Receive(char * buffer, int64_t size_buffer) { ...@@ -205,7 +205,7 @@ int64_t TCPSocket::Receive(char * buffer, int64_t size_buffer) {
do { // retry if EINTR failure appears do { // retry if EINTR failure appears
number_recv = recv(socket_, buffer, size_buffer, 0); number_recv = recv(socket_, buffer, size_buffer, 0);
} while (number_recv == -1 && errno == EINTR); } while (number_recv == -1 && errno == EINTR);
if (number_recv == -1) { if (number_recv == -1 && errno != EAGAIN && errno != EWOULDBLOCK) {
LOG(ERROR) << "recv error: " << strerror(errno); LOG(ERROR) << "recv error: " << strerror(errno);
} }
......
...@@ -70,12 +70,12 @@ class TCPSocket { ...@@ -70,12 +70,12 @@ class TCPSocket {
int * port_client); int * port_client);
/*! /*!
* \brief SetBlocking() is needed refering to this example of epoll: * \brief SetNonBlocking() is needed refering to this example of epoll:
* http://www.kernel.org/doc/man-pages/online/pages/man4/epoll.4.html * http://www.kernel.org/doc/man-pages/online/pages/man4/epoll.4.html
* \param flag flag for blocking * \param flag true for nonblocking, false for blocking
* \return true for success and false for failure * \return true for success and false for failure
*/ */
bool SetBlocking(bool flag); bool SetNonBlocking(bool flag);
/*! /*!
* \brief Set timeout for socket * \brief Set timeout for socket
......
...@@ -87,8 +87,10 @@ DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCCreateSender") ...@@ -87,8 +87,10 @@ DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCCreateSender")
.set_body([] (DGLArgs args, DGLRetValue* rv) { .set_body([] (DGLArgs args, DGLRetValue* rv) {
int64_t msg_queue_size = args[0]; int64_t msg_queue_size = args[0];
std::string type = args[1]; std::string type = args[1];
int max_thread_count = args[2];
if (type.compare("socket") == 0) { if (type.compare("socket") == 0) {
RPCContext::ThreadLocal()->sender = std::make_shared<network::SocketSender>(msg_queue_size); RPCContext::ThreadLocal()->sender =
std::make_shared<network::SocketSender>(msg_queue_size, max_thread_count);
} else { } else {
LOG(FATAL) << "Unknown communicator type for rpc receiver: " << type; LOG(FATAL) << "Unknown communicator type for rpc receiver: " << type;
} }
...@@ -98,8 +100,10 @@ DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCCreateReceiver") ...@@ -98,8 +100,10 @@ DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCCreateReceiver")
.set_body([] (DGLArgs args, DGLRetValue* rv) { .set_body([] (DGLArgs args, DGLRetValue* rv) {
int64_t msg_queue_size = args[0]; int64_t msg_queue_size = args[0];
std::string type = args[1]; std::string type = args[1];
int max_thread_count = args[2];
if (type.compare("socket") == 0) { if (type.compare("socket") == 0) {
RPCContext::ThreadLocal()->receiver = std::make_shared<network::SocketReceiver>(msg_queue_size); RPCContext::ThreadLocal()->receiver =
std::make_shared<network::SocketReceiver>(msg_queue_size, max_thread_count);
} else { } else {
LOG(FATAL) << "Unknown communicator type for rpc sender: " << type; LOG(FATAL) << "Unknown communicator type for rpc sender: " << type;
} }
......
/*!
* Copyright (c) 2021 by Contributors
* \file semaphore_wrapper.cc
* \brief A simple corss platform semaphore wrapper
*/
#include "semaphore_wrapper.h"
#include <dmlc/logging.h>
namespace dgl {
namespace runtime {
#ifdef _WIN32
Semaphore::Semaphore() {
sem_ = CreateSemaphore(nullptr, 0, INT_MAX, nullptr);
if (!sem_) {
LOG(FATAL) << "Cannot create semaphore";
}
}
void Semaphore::Wait() {
WaitForSingleObject(sem_, INFINITE);
}
void Semaphore::Post() {
ReleaseSemaphore(sem_, 1, nullptr);
}
#else
Semaphore::Semaphore() {
sem_init(&sem_, 0, 0);
}
void Semaphore::Wait() {
sem_wait(&sem_);
}
void Semaphore::Post() {
sem_post(&sem_);
}
#endif
} // namespace runtime
} // namespace dgl
/*!
* Copyright (c) 2021 by Contributors
* \file semaphore_wrapper.h
* \brief A simple corss platform semaphore wrapper
*/
#ifndef DGL_RUNTIME_SEMAPHORE_WRAPPER_H_
#define DGL_RUNTIME_SEMAPHORE_WRAPPER_H_
#ifdef _WIN32
#include <windows.h>
#else
#include <semaphore.h>
#endif
namespace dgl {
namespace runtime {
/*!
* \brief A simple crossplatform Semaphore wrapper
*/
class Semaphore {
public:
/*!
* \brief Semaphore constructor
*/
Semaphore();
/*!
* \brief blocking wait, decrease semaphore by 1
*/
void Wait();
/*!
* \brief increase semaphore by 1
*/
void Post();
private:
#ifdef _WIN32
HANDLE sem_;
#else
sem_t sem_;
#endif
};
} // namespace runtime
} // namespace dgl
#endif // DGL_RUNTIME_SEMAPHORE_WRAPPER_H_
...@@ -25,6 +25,7 @@ using dgl::network::Message; ...@@ -25,6 +25,7 @@ using dgl::network::Message;
using dgl::network::DefaultMessageDeleter; using dgl::network::DefaultMessageDeleter;
const int64_t kQueueSize = 500 * 1024; const int64_t kQueueSize = 500 * 1024;
const int kThreadNum = 2;
#ifndef WIN32 #ifndef WIN32
...@@ -61,7 +62,7 @@ TEST(SocketCommunicatorTest, SendAndRecv) { ...@@ -61,7 +62,7 @@ TEST(SocketCommunicatorTest, SendAndRecv) {
} }
void start_client() { void start_client() {
SocketSender sender(kQueueSize); SocketSender sender(kQueueSize, kThreadNum);
for (int i = 0; i < kNumReceiver; ++i) { for (int i = 0; i < kNumReceiver; ++i) {
sender.AddReceiver(ip_addr[i], i); sender.AddReceiver(ip_addr[i], i);
} }
...@@ -89,7 +90,7 @@ void start_client() { ...@@ -89,7 +90,7 @@ void start_client() {
void start_server(int id) { void start_server(int id) {
sleep(5); sleep(5);
SocketReceiver receiver(kQueueSize); SocketReceiver receiver(kQueueSize, kThreadNum);
receiver.Wait(ip_addr[id], kNumSender); receiver.Wait(ip_addr[id], kNumSender);
for (int i = 0; i < kNumMessage; ++i) { for (int i = 0; i < kNumMessage; ++i) {
for (int n = 0; n < kNumSender; ++n) { for (int n = 0; n < kNumSender; ++n) {
...@@ -168,7 +169,7 @@ static void start_client() { ...@@ -168,7 +169,7 @@ static void start_client() {
std::string ip_addr((std::istreambuf_iterator<char>(t)), std::string ip_addr((std::istreambuf_iterator<char>(t)),
std::istreambuf_iterator<char>()); std::istreambuf_iterator<char>());
t.close(); t.close();
SocketSender sender(kQueueSize); SocketSender sender(kQueueSize, kThreadNum);
sender.AddReceiver(ip_addr.c_str(), 0); sender.AddReceiver(ip_addr.c_str(), 0);
sender.Connect(); sender.Connect();
char* str_data = new char[9]; char* str_data = new char[9];
...@@ -185,7 +186,7 @@ static bool start_server() { ...@@ -185,7 +186,7 @@ static bool start_server() {
std::string ip_addr((std::istreambuf_iterator<char>(t)), std::string ip_addr((std::istreambuf_iterator<char>(t)),
std::istreambuf_iterator<char>()); std::istreambuf_iterator<char>());
t.close(); t.close();
SocketReceiver receiver(kQueueSize); SocketReceiver receiver(kQueueSize, kThreadNum);
receiver.Wait(ip_addr.c_str(), 1); receiver.Wait(ip_addr.c_str(), 1);
Message msg; Message msg;
EXPECT_EQ(receiver.RecvFrom(&msg, 0), REMOVE_SUCCESS); EXPECT_EQ(receiver.RecvFrom(&msg, 0), REMOVE_SUCCESS);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment