Unverified Commit 5cf48fc6 authored by Jingcheng Yu's avatar Jingcheng Yu Committed by GitHub
Browse files

[Feature] Implement one thread multiple socket (#3200)


Co-authored-by: default avatarJingchengYu94 <jingchengyu94@gmail.com>
parent 179d6aab
......@@ -34,6 +34,7 @@ dgl_option(BUILD_CPP_TEST "Build cpp unittest executables" OFF)
dgl_option(LIBCXX_ENABLE_PARALLEL_ALGORITHMS "Enable the parallel algorithms library. This requires the PSTL to be available." OFF)
dgl_option(USE_S3 "Build with S3 support" OFF)
dgl_option(USE_HDFS "Build with HDFS support" OFF) # Set env HADOOP_HDFS_HOME if needed
dgl_option(USE_EPOLL "Build with epoll for socket communicator" OFF)
# Set debug compile option for gdb, only happens when -DCMAKE_BUILD_TYPE=DEBUG
if (NOT MSVC)
......@@ -120,6 +121,14 @@ if(USE_AVX)
endif(USE_LIBXSMM)
endif(USE_AVX)
if (USE_EPOLL)
check_include_file("sys/epoll.h" USE_EPOLL)
if (USE_EPOLL)
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -DUSE_EPOLL")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DUSE_EPOLL")
endif()
endif ()
# Build with fp16 to support mixed precision training.
if(USE_FP16)
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -DUSE_FP16")
......
"""RPC components. They are typically functions or utilities used by both
server and clients."""
import os
import abc
import pickle
import random
......@@ -111,7 +112,8 @@ def create_sender(max_queue_size, net_type):
net_type : str
Networking type. Current options are: 'socket'.
"""
_CAPI_DGLRPCCreateSender(int(max_queue_size), net_type)
max_thread_count = int(os.getenv('DGL_SOCKET_MAX_THREAD_COUNT', '0'))
_CAPI_DGLRPCCreateSender(int(max_queue_size), net_type, max_thread_count)
def create_receiver(max_queue_size, net_type):
"""Create rpc receiver of this process.
......@@ -123,7 +125,8 @@ def create_receiver(max_queue_size, net_type):
net_type : str
Networking type. Current options are: 'socket'.
"""
_CAPI_DGLRPCCreateReceiver(int(max_queue_size), net_type)
max_thread_count = int(os.getenv('DGL_SOCKET_MAX_THREAD_COUNT', '0'))
_CAPI_DGLRPCCreateReceiver(int(max_queue_size), net_type, max_thread_count)
def finalize_sender():
"""Finalize rpc sender of this process.
......
......@@ -206,7 +206,7 @@ DGL_REGISTER_GLOBAL("network._CAPI_DGLSenderCreate")
int64_t msg_queue_size = args[1];
network::Sender* sender = nullptr;
if (type == "socket") {
sender = new network::SocketSender(msg_queue_size);
sender = new network::SocketSender(msg_queue_size, 0);
} else {
LOG(FATAL) << "Unknown communicator type: " << type;
}
......@@ -220,7 +220,7 @@ DGL_REGISTER_GLOBAL("network._CAPI_DGLReceiverCreate")
int64_t msg_queue_size = args[1];
network::Receiver* receiver = nullptr;
if (type == "socket") {
receiver = new network::SocketReceiver(msg_queue_size);
receiver = new network::SocketReceiver(msg_queue_size, 0);
} else {
LOG(FATAL) << "Unknown communicator type: " << type;
}
......
......@@ -28,11 +28,14 @@ class Sender {
/*!
* \brief Sender constructor
* \param queue_size size (bytes) of message queue.
* \param max_thread_count size of thread pool. 0 for no limit
* Note that, the queue_size parameter is optional.
*/
explicit Sender(int64_t queue_size = 0) {
explicit Sender(int64_t queue_size = 0, int max_thread_count = 0) {
CHECK_GE(queue_size, 0);
CHECK_GE(max_thread_count, 0);
queue_size_ = queue_size;
max_thread_count_ = max_thread_count;
}
virtual ~Sender() {}
......@@ -86,6 +89,10 @@ class Sender {
* \brief Size of message queue
*/
int64_t queue_size_;
/*!
* \brief Size of thread pool. 0 for no limit
*/
int max_thread_count_;
};
/*!
......@@ -101,13 +108,16 @@ class Receiver {
/*!
* \brief Receiver constructor
* \param queue_size size of message queue.
* \param max_thread_count size of thread pool. 0 for no limit
* Note that, the queue_size parameter is optional.
*/
explicit Receiver(int64_t queue_size = 0) {
explicit Receiver(int64_t queue_size = 0, int max_thread_count = 0) {
if (queue_size < 0) {
LOG(FATAL) << "queue_size cannot be a negative number.";
}
CHECK_GE(max_thread_count, 0);
queue_size_ = queue_size;
max_thread_count_ = max_thread_count;
}
virtual ~Receiver() {}
......@@ -165,6 +175,10 @@ class Receiver {
* \brief Size of message queue
*/
int64_t queue_size_;
/*!
* \brief Size of thread pool. 0 for no limit
*/
int max_thread_count_;
};
} // namespace network
......
......@@ -72,6 +72,7 @@ STATUS MessageQueue::Remove(Message* msg, bool is_blocking) {
queue_.pop();
msg->data = old_msg.data;
msg->size = old_msg.size;
msg->receiver_id = old_msg.receiver_id;
msg->deallocator = old_msg.deallocator;
free_size_ += old_msg.size;
cond_not_full_.notify_one();
......
......@@ -56,6 +56,10 @@ struct Message {
* \brief message size in bytes
*/
int64_t size;
/*!
* \brief message receiver id
*/
int receiver_id = -1;
/*!
* \brief user-defined deallocator, which can be nullptr
*/
......
......@@ -12,6 +12,7 @@
#include "socket_communicator.h"
#include "../../c_api_common.h"
#include "socket_pool.h"
#ifdef _WIN32
#include <windows.h>
......@@ -51,15 +52,20 @@ void SocketSender::AddReceiver(const char* addr, int recv_id) {
address.ip = ip_and_port[0];
address.port = std::stoi(ip_and_port[1]);
receiver_addrs_[recv_id] = address;
msg_queue_[recv_id] = std::make_shared<MessageQueue>(queue_size_);
}
bool SocketSender::Connect() {
// Create N sockets for Receiver
int receiver_count = static_cast<int>(receiver_addrs_.size());
if (max_thread_count_ == 0 || max_thread_count_ > receiver_count) {
max_thread_count_ = receiver_count;
}
sockets_.resize(max_thread_count_);
for (const auto& r : receiver_addrs_) {
int ID = r.first;
sockets_[ID] = std::make_shared<TCPSocket>();
TCPSocket* client_socket = sockets_[ID].get();
int receiver_id = r.first;
int thread_id = receiver_id % max_thread_count_;
sockets_[thread_id][receiver_id] = std::make_shared<TCPSocket>();
TCPSocket* client_socket = sockets_[thread_id][receiver_id].get();
bool bo = false;
int try_count = 0;
const char* ip = r.second.ip.c_str();
......@@ -83,12 +89,17 @@ bool SocketSender::Connect() {
if (bo == false) {
return bo;
}
}
for (int thread_id = 0; thread_id < max_thread_count_; ++thread_id) {
msg_queue_.push_back(std::make_shared<MessageQueue>(queue_size_));
// Create a new thread for this socket connection
threads_[ID] = std::make_shared<std::thread>(
threads_.push_back(std::make_shared<std::thread>(
SendLoop,
client_socket,
msg_queue_[ID].get());
sockets_[thread_id],
msg_queue_[thread_id]));
}
return true;
}
......@@ -96,53 +107,48 @@ STATUS SocketSender::Send(Message msg, int recv_id) {
CHECK_NOTNULL(msg.data);
CHECK_GT(msg.size, 0);
CHECK_GE(recv_id, 0);
msg.receiver_id = recv_id;
// Add data message to message queue
STATUS code = msg_queue_[recv_id]->Add(msg);
STATUS code = msg_queue_[recv_id % max_thread_count_]->Add(msg);
return code;
}
void SocketSender::Finalize() {
// Send a signal to tell the msg_queue to finish its job
for (auto& mq : msg_queue_) {
for (int i = 0; i < max_thread_count_; ++i) {
// wait until queue is empty
while (mq.second->Empty() == false) {
auto& mq = msg_queue_[i];
while (mq->Empty() == false) {
#ifdef _WIN32
// just loop
#else // !_WIN32
usleep(1000);
#endif // _WIN32
}
int ID = mq.first;
mq.second->SignalFinished(ID);
// All queues have only one producer, which is main thread, so
// the producerID argument here should be zero.
mq->SignalFinished(0);
}
// Block main thread until all socket-threads finish their jobs
for (auto& thread : threads_) {
thread.second->join();
thread->join();
}
// Clear all sockets
for (auto& socket : sockets_) {
for (auto& group_sockets_ : sockets_) {
for (auto &socket : group_sockets_) {
socket.second->Close();
}
}
}
void SocketSender::SendLoop(TCPSocket* socket, MessageQueue* queue) {
CHECK_NOTNULL(socket);
CHECK_NOTNULL(queue);
bool exit = false;
while (!exit) {
Message msg;
STATUS code = queue->Remove(&msg);
if (code == QUEUE_CLOSE) {
msg.size = 0; // send an end-signal to receiver
exit = true;
}
void SendCore(Message msg, TCPSocket* socket) {
// First send the size
// If exit == true, we will send zero size to reciever
int64_t sent_bytes = 0;
while (static_cast<size_t>(sent_bytes) < sizeof(int64_t)) {
int64_t max_len = sizeof(int64_t) - sent_bytes;
int64_t tmp = socket->Send(
reinterpret_cast<char*>(&msg.size)+sent_bytes,
reinterpret_cast<char*>(&msg.size) + sent_bytes,
max_len);
CHECK_NE(tmp, -1);
sent_bytes += tmp;
......@@ -159,6 +165,22 @@ void SocketSender::SendLoop(TCPSocket* socket, MessageQueue* queue) {
if (msg.deallocator != nullptr) {
msg.deallocator(&msg);
}
}
void SocketSender::SendLoop(
std::unordered_map<int, std::shared_ptr<TCPSocket>> sockets,
std::shared_ptr<MessageQueue> queue) {
for (;;) {
Message msg;
STATUS code = queue->Remove(&msg);
if (code == QUEUE_CLOSE) {
msg.size = 0; // send an end-signal to receiver
for (auto& socket : sockets) {
SendCore(msg, socket.second.get());
}
break;
}
SendCore(msg, sockets[msg.receiver_id].get());
}
}
......@@ -187,16 +209,20 @@ bool SocketReceiver::Wait(const char* addr, int num_sender) {
int port = stoi(ip_and_port[1]);
// Initialize message queue for each connection
num_sender_ = num_sender;
for (int i = 0; i < num_sender_; ++i) {
msg_queue_[i] = std::make_shared<MessageQueue>(queue_size_);
#ifdef USE_EPOLL
if (max_thread_count_ == 0 || max_thread_count_ > num_sender_) {
max_thread_count_ = num_sender_;
}
mq_iter_ = msg_queue_.begin();
#else
max_thread_count_ = num_sender_;
#endif
// Initialize socket and socket-thread
server_socket_ = new TCPSocket();
// Bind socket
if (server_socket_->Bind(ip.c_str(), port) == false) {
LOG(FATAL) << "Cannot bind to " << ip << ":" << port;
}
// Listen
if (server_socket_->Listen(kMaxConnection) == false) {
LOG(FATAL) << "Cannot listen on " << ip << ":" << port;
......@@ -204,27 +230,39 @@ bool SocketReceiver::Wait(const char* addr, int num_sender) {
// Accept all sender sockets
std::string accept_ip;
int accept_port;
sockets_.resize(max_thread_count_);
for (int i = 0; i < num_sender_; ++i) {
sockets_[i] = std::make_shared<TCPSocket>();
if (server_socket_->Accept(sockets_[i].get(), &accept_ip, &accept_port) == false) {
int thread_id = i % max_thread_count_;
auto socket = std::make_shared<TCPSocket>();
sockets_[thread_id][i] = socket;
msg_queue_[i] = std::make_shared<MessageQueue>(queue_size_);
if (server_socket_->Accept(socket.get(), &accept_ip, &accept_port) == false) {
LOG(WARNING) << "Error on accept socket.";
return false;
}
}
mq_iter_ = msg_queue_.begin();
for (int thread_id = 0; thread_id < max_thread_count_; ++thread_id) {
// create new thread for each socket
threads_[i] = std::make_shared<std::thread>(
threads_.push_back(std::make_shared<std::thread>(
RecvLoop,
sockets_[i].get(),
msg_queue_[i].get());
sockets_[thread_id],
msg_queue_,
&queue_sem_));
}
return true;
}
STATUS SocketReceiver::Recv(Message* msg, int* send_id) {
// loop until get a message
// queue_sem_ is a semaphore indicating how many elements in multiple
// message queues.
// When calling queue_sem_.Wait(), this Recv will be suspended until
// queue_sem_ > 0, decrease queue_sem_ by 1, then start to fetch a message.
queue_sem_.Wait();
for (;;) {
for (; mq_iter_ != msg_queue_.end(); ++mq_iter_) {
// We use non-block remove here
STATUS code = mq_iter_->second->Remove(msg, false);
if (code == QUEUE_EMPTY) {
continue; // jump to the next queue
......@@ -240,6 +278,7 @@ STATUS SocketReceiver::Recv(Message* msg, int* send_id) {
STATUS SocketReceiver::RecvFrom(Message* msg, int send_id) {
// Get message from specified message queue
queue_sem_.Wait();
STATUS code = msg_queue_[send_id]->Remove(msg);
return code;
}
......@@ -255,47 +294,93 @@ void SocketReceiver::Finalize() {
usleep(1000);
#endif // _WIN32
}
int ID = mq.first;
mq.second->SignalFinished(ID);
mq.second->SignalFinished(mq.first);
}
// Block main thread until all socket-threads finish their jobs
for (auto& thread : threads_) {
thread.second->join();
thread->join();
}
// Clear all sockets
for (auto& socket : sockets_) {
for (auto& group_sockets : sockets_) {
for (auto& socket : group_sockets) {
socket.second->Close();
}
}
server_socket_->Close();
delete server_socket_;
}
void SocketReceiver::RecvLoop(TCPSocket* socket, MessageQueue* queue) {
CHECK_NOTNULL(socket);
CHECK_NOTNULL(queue);
for (;;) {
// If main thread had finished its job
if (queue->EmptyAndNoMoreAdd()) {
return; // exit loop thread
}
// First recv the size
int64_t RecvDataSize(TCPSocket* socket) {
int64_t received_bytes = 0;
int64_t data_size = 0;
while (static_cast<size_t>(received_bytes) < sizeof(int64_t)) {
int64_t max_len = sizeof(int64_t) - received_bytes;
int64_t tmp = socket->Receive(
reinterpret_cast<char*>(&data_size)+received_bytes,
reinterpret_cast<char*>(&data_size) + received_bytes,
max_len);
CHECK_NE(tmp, -1);
if (tmp == -1) {
if (received_bytes > 0) {
// We want to finish reading full data_size
continue;
}
return -1;
}
received_bytes += tmp;
}
if (data_size < 0) {
LOG(FATAL) << "Recv data error (data_size: " << data_size << ")";
} else if (data_size == 0) {
// This is an end-signal sent by client
return data_size;
}
void RecvData(TCPSocket* socket, char* buffer, const int64_t &data_size,
int64_t *received_bytes) {
while (*received_bytes < data_size) {
int64_t max_len = data_size - *received_bytes;
int64_t tmp = socket->Receive(buffer + *received_bytes, max_len);
if (tmp == -1) {
// Socket not ready, no more data to read
return;
} else {
char* buffer = nullptr;
}
*received_bytes += tmp;
}
}
void SocketReceiver::RecvLoop(
std::unordered_map<int /* Sender (virtual) ID */,
std::shared_ptr<TCPSocket>> sockets,
std::unordered_map<int /* Sender (virtual) ID */,
std::shared_ptr<MessageQueue>> queues,
runtime::Semaphore *queue_sem) {
std::unordered_map<int, std::unique_ptr<RecvContext>> recv_contexts;
SocketPool socket_pool;
for (auto& socket : sockets) {
auto &sender_id = socket.first;
socket_pool.AddSocket(socket.second, sender_id);
recv_contexts[sender_id] = std::unique_ptr<RecvContext>(new RecvContext());
}
// Main loop to receive messages
for (;;) {
int sender_id;
// Get active socket using epoll
std::shared_ptr<TCPSocket> socket = socket_pool.GetActiveSocket(&sender_id);
if (queues[sender_id]->EmptyAndNoMoreAdd()) {
// This sender has already stopped
if (socket_pool.RemoveSocket(socket) == 0) {
return;
}
continue;
}
// Nonblocking socket might be interrupted at any point. So we need to
// store the partially received data
std::unique_ptr<RecvContext> &ctx = recv_contexts[sender_id];
int64_t &data_size = ctx->data_size;
int64_t &received_bytes = ctx->received_bytes;
char*& buffer = ctx->buffer;
if (data_size == -1) {
// This is a new message, so receive the data size first
data_size = RecvDataSize(socket.get());
if (data_size > 0) {
try {
buffer = new char[data_size];
} catch(const std::bad_alloc&) {
......@@ -303,17 +388,28 @@ void SocketReceiver::RecvLoop(TCPSocket* socket, MessageQueue* queue) {
<< "(message size: " << data_size << ")";
}
received_bytes = 0;
while (received_bytes < data_size) {
int64_t max_len = data_size - received_bytes;
int64_t tmp = socket->Receive(buffer+received_bytes, max_len);
CHECK_NE(tmp, -1);
received_bytes += tmp;
} else if (data_size == 0) {
// Received stop signal
if (socket_pool.RemoveSocket(socket) == 0) {
return;
}
}
}
RecvData(socket.get(), buffer, data_size, &received_bytes);
if (received_bytes >= data_size) {
// Full data received, create Message and push to queue
Message msg;
msg.data = buffer;
msg.size = data_size;
msg.deallocator = DefaultMessageDeleter;
queue->Add(msg);
queues[sender_id]->Add(msg);
// Reset recv context
data_size = -1;
// Signal queue semaphore
queue_sem->Post();
}
}
}
......
......@@ -12,6 +12,7 @@
#include <unordered_map>
#include <memory>
#include "../../runtime/semaphore_wrapper.h"
#include "communicator.h"
#include "msg_queue.h"
#include "tcp_socket.h"
......@@ -42,8 +43,10 @@ class SocketSender : public Sender {
/*!
* \brief Sender constructor
* \param queue_size size of message queue
* \param max_thread_count size of thread pool. 0 for no limit
*/
explicit SocketSender(int64_t queue_size) : Sender(queue_size) {}
SocketSender(int64_t queue_size, int max_thread_count)
: Sender(queue_size, max_thread_count) {}
/*!
* \brief Add receiver's address and ID to the sender's namebook
......@@ -93,7 +96,8 @@ class SocketSender : public Sender {
/*!
* \brief socket for each connection of receiver
*/
std::unordered_map<int /* receiver ID */, std::shared_ptr<TCPSocket>> sockets_;
std::vector<std::unordered_map<int /* receiver ID */,
std::shared_ptr<TCPSocket>>> sockets_;
/*!
* \brief receivers' address
......@@ -101,24 +105,27 @@ class SocketSender : public Sender {
std::unordered_map<int /* receiver ID */, IPAddr> receiver_addrs_;
/*!
* \brief message queue for each socket connection
* \brief message queue for each thread
*/
std::unordered_map<int /* receiver ID */, std::shared_ptr<MessageQueue>> msg_queue_;
std::vector<std::shared_ptr<MessageQueue>> msg_queue_;
/*!
* \brief Independent thread for each socket connection
* \brief Independent thread
*/
std::unordered_map<int /* receiver ID */, std::shared_ptr<std::thread>> threads_;
std::vector<std::shared_ptr<std::thread>> threads_;
/*!
* \brief Send-loop for each socket in per-thread
* \param socket TCPSocket for current connection
* \param queue message_queue for current connection
* \brief Send-loop for each thread
* \param sockets TCPSockets for current thread
* \param queue message_queue for current thread
*
* Note that, the SendLoop will finish its loop-job and exit thread
* when the main thread invokes Signal() API on the message queue.
*/
static void SendLoop(TCPSocket* socket, MessageQueue* queue);
static void SendLoop(
std::unordered_map<int /* Receiver (virtual) ID */,
std::shared_ptr<TCPSocket>> sockets,
std::shared_ptr<MessageQueue> queue);
};
/*!
......@@ -131,8 +138,10 @@ class SocketReceiver : public Receiver {
/*!
* \brief Receiver constructor
* \param queue_size size of message queue.
* \param max_thread_count size of thread pool. 0 for no limit
*/
explicit SocketReceiver(int64_t queue_size) : Receiver(queue_size) {}
SocketReceiver(int64_t queue_size, int max_thread_count)
: Receiver(queue_size, max_thread_count) {}
/*!
* \brief Wait for all the Senders to connect
......@@ -183,6 +192,11 @@ class SocketReceiver : public Receiver {
inline std::string Type() const { return std::string("socket"); }
private:
struct RecvContext {
int64_t data_size = -1;
int64_t received_bytes = 0;
char *buffer = nullptr;
};
/*!
* \brief number of sender
*/
......@@ -196,28 +210,41 @@ class SocketReceiver : public Receiver {
/*!
* \brief socket for each client connections
*/
std::unordered_map<int /* Sender (virutal) ID */, std::shared_ptr<TCPSocket>> sockets_;
std::vector<std::unordered_map<int /* Sender (virutal) ID */,
std::shared_ptr<TCPSocket>>> sockets_;
/*!
* \brief Message queue for each socket connection
*/
std::unordered_map<int /* Sender (virtual) ID */, std::shared_ptr<MessageQueue>> msg_queue_;
std::unordered_map<int /* Sender (virtual) ID */,
std::shared_ptr<MessageQueue>> msg_queue_;
std::unordered_map<int, std::shared_ptr<MessageQueue>>::iterator mq_iter_;
/*!
* \brief Independent thead for each socket connection
* \brief Independent thead
*/
std::unordered_map<int /* Sender (virtual) ID */, std::shared_ptr<std::thread>> threads_;
std::vector<std::shared_ptr<std::thread>> threads_;
/*!
* \brief Recv-loop for each socket in per-thread
* \param socket client socket
* \param queue message queue
* \brief queue_sem_ semphore to indicate number of messages in multiple
* message queues to prevent busy wait of Recv
*/
runtime::Semaphore queue_sem_;
/*!
* \brief Recv-loop for each thread
* \param sockets client sockets of current thread
* \param queue message queues of current thread
*
* Note that, the RecvLoop will finish its loop-job and exit thread
* when the main thread invokes Signal() API on the message queue.
*/
static void RecvLoop(TCPSocket* socket, MessageQueue* queue);
static void RecvLoop(
std::unordered_map<int /* Sender (virtual) ID */,
std::shared_ptr<TCPSocket>> sockets,
std::unordered_map<int /* Sender (virtual) ID */,
std::shared_ptr<MessageQueue>> queues,
runtime::Semaphore *queue_sem);
};
} // namespace network
......
/*!
* Copyright (c) 2021 by Contributors
* \file socket_pool.cc
* \brief Socket pool of nonblocking sockets for DGL distributed training.
*/
#include "socket_pool.h"
#include <dmlc/logging.h>
#include "tcp_socket.h"
#ifdef USE_EPOLL
#include <sys/epoll.h>
#endif
namespace dgl {
namespace network {
SocketPool::SocketPool() {
#ifdef USE_EPOLL
epfd_ = epoll_create1(0);
if (epfd_ < 0) {
LOG(FATAL) << "SocketPool cannot create epfd";
}
#endif
}
void SocketPool::AddSocket(std::shared_ptr<TCPSocket> socket, int socket_id,
int events) {
int fd = socket->Socket();
tcp_sockets_[fd] = socket;
socket_ids_[fd] = socket_id;
#ifdef USE_EPOLL
epoll_event e;
e.data.fd = fd;
if (events == READ) {
e.events = EPOLLIN;
} else if (events == WRITE) {
e.events = EPOLLOUT;
} else if (events == READ + WRITE) {
e.events = EPOLLIN | EPOLLOUT;
}
if (epoll_ctl(epfd_, EPOLL_CTL_ADD, fd, &e) < 0) {
LOG(FATAL) << "SocketPool cannot add socket";
}
socket->SetNonBlocking(true);
#else
if (tcp_sockets_.size() > 1) {
LOG(FATAL) << "SocketPool supports only one socket if not use epoll."
"Please turn on USE_EPOLL on building";
}
#endif
}
size_t SocketPool::RemoveSocket(std::shared_ptr<TCPSocket> socket) {
int fd = socket->Socket();
socket_ids_.erase(fd);
tcp_sockets_.erase(fd);
#ifdef USE_EPOLL
epoll_ctl(epfd_, EPOLL_CTL_DEL, fd, NULL);
#endif
return socket_ids_.size();
}
SocketPool::~SocketPool() {
#ifdef USE_EPOLL
for (auto& id : socket_ids_) {
int fd = id.first;
epoll_ctl(epfd_, EPOLL_CTL_DEL, fd, NULL);
}
#endif
}
std::shared_ptr<TCPSocket> SocketPool::GetActiveSocket(int* socket_id) {
if (socket_ids_.empty()) {
return nullptr;
}
for (;;) {
while (pending_fds_.empty()) {
Wait();
}
int fd = pending_fds_.front();
pending_fds_.pop();
// Check if this socket is not removed
if (socket_ids_.find(fd) != socket_ids_.end()) {
*socket_id = socket_ids_[fd];
return tcp_sockets_[fd];
}
}
return nullptr;
}
void SocketPool::Wait() {
#ifdef USE_EPOLL
static const int MAX_EVENTS = 10;
epoll_event events[MAX_EVENTS];
int nfd = epoll_wait(epfd_, events, MAX_EVENTS, -1 /*Timeout*/);
for (int i = 0; i < nfd; ++i) {
pending_fds_.push(events[i].data.fd);
}
#else
pending_fds_.push(tcp_sockets_.begin()->second->Socket());
#endif
}
} // namespace network
} // namespace dgl
/*!
* Copyright (c) 2021 by Contributors
* \file socket_pool.h
* \brief Socket pool of nonblocking sockets for DGL distributed training.
*/
#ifndef DGL_RPC_NETWORK_SOCKET_POOL_H_
#define DGL_RPC_NETWORK_SOCKET_POOL_H_
#include <unordered_map>
#include <queue>
#include <memory>
namespace dgl {
namespace network {
class TCPSocket;
/*!
* \brief SocketPool maintains a group of nonblocking sockets, and can provide
* active sockets.
* Currently SocketPool is based on epoll, a scalable I/O event notification
* mechanism in Linux operating system.
*/
class SocketPool {
public:
/*!
* \brief socket mode read/receive
*/
static const int READ = 1;
/*!
* \brief socket mode write/send
*/
static const int WRITE = 2;
/*!
* \brief SocketPool constructor
*/
SocketPool();
/*!
* \brief Add a socket to SocketPool
* \param socket tcp socket to add
* \param socket_id receiver/sender id of the socket
* \param events READ, WRITE or READ + WRITE
*/
void AddSocket(std::shared_ptr<TCPSocket> socket, int socket_id,
int events = READ);
/*!
* \brief Remove socket from SocketPool
* \param socket tcp socket to remove
* \return number of remaing sockets in the pool
*/
size_t RemoveSocket(std::shared_ptr<TCPSocket> socket);
/*!
* \brief SocketPool destructor
*/
~SocketPool();
/*!
* \brief Get current active socket. This is a blocking method
* \param socket_id output parameter of the socket_id of active socket
* \return active TCPSocket
*/
std::shared_ptr<TCPSocket> GetActiveSocket(int* socket_id);
private:
/*!
* \brief Wait for event notification
*/
void Wait();
/*!
* \brief map from fd to TCPSocket
*/
std::unordered_map<int, std::shared_ptr<TCPSocket>> tcp_sockets_;
/*!
* \brief map from fd to socket_id
*/
std::unordered_map<int, int> socket_ids_;
/*!
* \brief fd for epoll base
*/
int epfd_;
/*!
* \brief queue for current active fds
*/
std::queue<int> pending_fds_;
};
} // namespace network
} // namespace dgl
#endif // DGL_RPC_NETWORK_SOCKET_POOL_H_
......@@ -119,7 +119,7 @@ bool TCPSocket::Accept(TCPSocket * socket, std::string * ip, int * port) {
}
#ifdef _WIN32
bool TCPSocket::SetBlocking(bool flag) {
bool TCPSocket::SetNonBlocking(bool flag) {
int result;
u_long argp = flag ? 1 : 0;
......@@ -134,7 +134,7 @@ bool TCPSocket::SetBlocking(bool flag) {
return true;
}
#else // !_WIN32
bool TCPSocket::SetBlocking(bool flag) {
bool TCPSocket::SetNonBlocking(bool flag) {
int opts;
if ((opts = fcntl(socket_, F_GETFL)) < 0) {
......@@ -205,7 +205,7 @@ int64_t TCPSocket::Receive(char * buffer, int64_t size_buffer) {
do { // retry if EINTR failure appears
number_recv = recv(socket_, buffer, size_buffer, 0);
} while (number_recv == -1 && errno == EINTR);
if (number_recv == -1) {
if (number_recv == -1 && errno != EAGAIN && errno != EWOULDBLOCK) {
LOG(ERROR) << "recv error: " << strerror(errno);
}
......
......@@ -70,12 +70,12 @@ class TCPSocket {
int * port_client);
/*!
* \brief SetBlocking() is needed refering to this example of epoll:
* \brief SetNonBlocking() is needed refering to this example of epoll:
* http://www.kernel.org/doc/man-pages/online/pages/man4/epoll.4.html
* \param flag flag for blocking
* \param flag true for nonblocking, false for blocking
* \return true for success and false for failure
*/
bool SetBlocking(bool flag);
bool SetNonBlocking(bool flag);
/*!
* \brief Set timeout for socket
......
......@@ -87,8 +87,10 @@ DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCCreateSender")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
int64_t msg_queue_size = args[0];
std::string type = args[1];
int max_thread_count = args[2];
if (type.compare("socket") == 0) {
RPCContext::ThreadLocal()->sender = std::make_shared<network::SocketSender>(msg_queue_size);
RPCContext::ThreadLocal()->sender =
std::make_shared<network::SocketSender>(msg_queue_size, max_thread_count);
} else {
LOG(FATAL) << "Unknown communicator type for rpc receiver: " << type;
}
......@@ -98,8 +100,10 @@ DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCCreateReceiver")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
int64_t msg_queue_size = args[0];
std::string type = args[1];
int max_thread_count = args[2];
if (type.compare("socket") == 0) {
RPCContext::ThreadLocal()->receiver = std::make_shared<network::SocketReceiver>(msg_queue_size);
RPCContext::ThreadLocal()->receiver =
std::make_shared<network::SocketReceiver>(msg_queue_size, max_thread_count);
} else {
LOG(FATAL) << "Unknown communicator type for rpc sender: " << type;
}
......
/*!
* Copyright (c) 2021 by Contributors
* \file semaphore_wrapper.cc
* \brief A simple corss platform semaphore wrapper
*/
#include "semaphore_wrapper.h"
#include <dmlc/logging.h>
namespace dgl {
namespace runtime {
#ifdef _WIN32
Semaphore::Semaphore() {
sem_ = CreateSemaphore(nullptr, 0, INT_MAX, nullptr);
if (!sem_) {
LOG(FATAL) << "Cannot create semaphore";
}
}
void Semaphore::Wait() {
WaitForSingleObject(sem_, INFINITE);
}
void Semaphore::Post() {
ReleaseSemaphore(sem_, 1, nullptr);
}
#else
Semaphore::Semaphore() {
sem_init(&sem_, 0, 0);
}
void Semaphore::Wait() {
sem_wait(&sem_);
}
void Semaphore::Post() {
sem_post(&sem_);
}
#endif
} // namespace runtime
} // namespace dgl
/*!
* Copyright (c) 2021 by Contributors
* \file semaphore_wrapper.h
* \brief A simple corss platform semaphore wrapper
*/
#ifndef DGL_RUNTIME_SEMAPHORE_WRAPPER_H_
#define DGL_RUNTIME_SEMAPHORE_WRAPPER_H_
#ifdef _WIN32
#include <windows.h>
#else
#include <semaphore.h>
#endif
namespace dgl {
namespace runtime {
/*!
* \brief A simple crossplatform Semaphore wrapper
*/
class Semaphore {
public:
/*!
* \brief Semaphore constructor
*/
Semaphore();
/*!
* \brief blocking wait, decrease semaphore by 1
*/
void Wait();
/*!
* \brief increase semaphore by 1
*/
void Post();
private:
#ifdef _WIN32
HANDLE sem_;
#else
sem_t sem_;
#endif
};
} // namespace runtime
} // namespace dgl
#endif // DGL_RUNTIME_SEMAPHORE_WRAPPER_H_
......@@ -25,6 +25,7 @@ using dgl::network::Message;
using dgl::network::DefaultMessageDeleter;
const int64_t kQueueSize = 500 * 1024;
const int kThreadNum = 2;
#ifndef WIN32
......@@ -61,7 +62,7 @@ TEST(SocketCommunicatorTest, SendAndRecv) {
}
void start_client() {
SocketSender sender(kQueueSize);
SocketSender sender(kQueueSize, kThreadNum);
for (int i = 0; i < kNumReceiver; ++i) {
sender.AddReceiver(ip_addr[i], i);
}
......@@ -89,7 +90,7 @@ void start_client() {
void start_server(int id) {
sleep(5);
SocketReceiver receiver(kQueueSize);
SocketReceiver receiver(kQueueSize, kThreadNum);
receiver.Wait(ip_addr[id], kNumSender);
for (int i = 0; i < kNumMessage; ++i) {
for (int n = 0; n < kNumSender; ++n) {
......@@ -168,7 +169,7 @@ static void start_client() {
std::string ip_addr((std::istreambuf_iterator<char>(t)),
std::istreambuf_iterator<char>());
t.close();
SocketSender sender(kQueueSize);
SocketSender sender(kQueueSize, kThreadNum);
sender.AddReceiver(ip_addr.c_str(), 0);
sender.Connect();
char* str_data = new char[9];
......@@ -185,7 +186,7 @@ static bool start_server() {
std::string ip_addr((std::istreambuf_iterator<char>(t)),
std::istreambuf_iterator<char>());
t.close();
SocketReceiver receiver(kQueueSize);
SocketReceiver receiver(kQueueSize, kThreadNum);
receiver.Wait(ip_addr.c_str(), 1);
Message msg;
EXPECT_EQ(receiver.RecvFrom(&msg, 0), REMOVE_SUCCESS);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment